示例#1
0
def plot_extinction(tag: str,
                    burnin: Optional[int] = None,
                    random: Optional[int] = None,
                    wavel_range: Optional[Tuple[float, float]] = None,
                    xlim: Optional[Tuple[float, float]] = None,
                    ylim: Optional[Tuple[float, float]] = None,
                    offset: Optional[Tuple[float, float]] = None,
                    output: str = 'extinction.pdf') -> None:
    """
    Function to plot random samples of the extinction, either from fitting a size distribution
    of enstatite grains (log-normal: ``lognorm_radius``, ``lognorm_sigma``, and
    ``lognorm_ext``; power-law: ``powerlaw_max``, ``powerlaw_exp``, and ``powerlaw_ext``),
    or from fitting ISM extinction (``ism_ext`` and optionally ``ism_red``).

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    burnin : int, None
        Number of burnin steps to exclude. No steps are excluded if set to ``None``. Only
        used for 3D samples (walkers, steps, parameters), i.e. after running MCMC with
        :func:`~species.analysis.fit_model.FitModel.run_mcmc`.
    random : int, None
        Number of randomly selected samples (drawn with replacement). All samples are used
        if set to ``None``.
    wavel_range : tuple(float, float), None
        Wavelength range (um) for the extinction. The default wavelength range (0.4, 10.) is used
        if set to ``None``.
    xlim : tuple(float, float), None
        Limits of the wavelength axis. The range is set automatically if set to ``None``.
    ylim : tuple(float, float), None
        Limits of the extinction axis. The range is set automatically if set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values are used if set to ``None``.
    output : str
        Output filename. A relative path is resolved against the current working directory.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If ``burnin`` is not smaller than the number of steps, or if the samples do not
        contain one of the supported sets of extinction parameters.
    """

    if burnin is None:
        burnin = 0

    if wavel_range is None:
        wavel_range = (0.4, 10.)

    mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
    mpl.rcParams['font.family'] = 'serif'

    plt.rc('axes', edgecolor='black', linewidth=2.2)

    species_db = database.Database()
    box = species_db.get_samples(tag)

    samples = box.samples

    if samples.ndim == 2 and random is not None:
        ran_index = np.random.randint(samples.shape[0], size=random)
        samples = samples[ran_index, ]

    elif samples.ndim == 3:
        # At least one step must remain after removing the burnin
        if burnin >= samples.shape[1]:
            raise ValueError(
                f'The \'burnin\' value is larger than or equal to the number of steps '
                f'({samples.shape[1]}) that are made by the walkers.')

        samples = samples[:, burnin:, :]

        if random is None:
            # Use all post-burnin samples: flatten (walker, step) into a single sample
            # axis so that samples[:, index] below selects one parameter
            samples = samples.reshape(-1, samples.shape[-1])

        else:
            ran_walker = np.random.randint(samples.shape[0], size=random)
            ran_step = np.random.randint(samples.shape[1], size=random)
            samples = samples[ran_walker, ran_step, :]

    plt.figure(1, figsize=(6, 3))
    gridsp = mpl.gridspec.GridSpec(1, 1)
    gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

    ax = plt.subplot(gridsp[0, 0])

    ax.tick_params(axis='both',
                   which='major',
                   colors='black',
                   labelcolor='black',
                   direction='in',
                   width=1,
                   length=5,
                   labelsize=12,
                   top=True,
                   bottom=True,
                   left=True,
                   right=True,
                   labelbottom=True)

    ax.tick_params(axis='both',
                   which='minor',
                   colors='black',
                   labelcolor='black',
                   direction='in',
                   width=1,
                   length=3,
                   labelsize=12,
                   top=True,
                   bottom=True,
                   left=True,
                   right=True,
                   labelbottom=True)

    ax.set_xlabel('Wavelength (µm)', fontsize=12)
    ax.set_ylabel('Extinction (mag)', fontsize=12)

    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])

    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    if offset is not None:
        ax.get_xaxis().set_label_coords(0.5, offset[0])
        ax.get_yaxis().set_label_coords(offset[1], 0.5)

    else:
        ax.get_xaxis().set_label_coords(0.5, -0.22)
        ax.get_yaxis().set_label_coords(-0.09, 0.5)

    sample_wavel = np.linspace(wavel_range[0], wavel_range[1], 100)

    def _plot_curve(sample_ext: np.ndarray) -> None:
        # Draw one sampled extinction curve; all curves share this thin,
        # semi-transparent style
        ax.plot(sample_wavel,
                sample_ext,
                ls='-',
                lw=0.5,
                color='black',
                alpha=0.5)

    def _plot_grain_samples(dust_type: str,
                            cross_optical,
                            grid_radius: np.ndarray,
                            grid_other: np.ndarray,
                            log_radius: np.ndarray,
                            other: np.ndarray,
                            dust_ext: np.ndarray) -> None:
        # Plot the extinction of every sample of a grain size distribution whose cross
        # sections are stored under 'dust/<dust_type>/mgsio3/crystalline' in the dust
        # database. The sampled radius parameter is the log10 of the radius.

        database_path = dust_util.check_dust_database()

        with h5py.File(database_path, 'r') as h5_file:
            cross_section = np.asarray(
                h5_file[f'dust/{dust_type}/mgsio3/crystalline/cross_section'])
            wavelength = np.asarray(
                h5_file[f'dust/{dust_type}/mgsio3/crystalline/wavelength'])

        cross_interp = RegularGridInterpolator(
            (wavelength, grid_radius, grid_other), cross_section)

        for log_r_item, other_item, ext_item in zip(log_radius, other, dust_ext):
            radius = 10.**log_r_item

            # V-band cross section; the optical interpolator takes (other, radius)
            cross_tmp = cross_optical['Generic/Bessell.V'](other_item, radius)

            # Number of grains for which 2.5*log10(e)*cross_tmp*n_grains equals the
            # sampled V-band extinction
            n_grains = ext_item / cross_tmp / 2.5 / np.log10(np.exp(1.))

            # Evaluate the cross section at all wavelengths with a single call
            points = np.column_stack((sample_wavel,
                                      np.full(sample_wavel.shape, radius),
                                      np.full(sample_wavel.shape, other_item)))

            sample_ext = 2.5 * np.log10(np.exp(1.)) * cross_interp(points) * n_grains

            _plot_curve(sample_ext)

    if 'lognorm_radius' in box.parameters and 'lognorm_sigma' in box.parameters and \
            'lognorm_ext' in box.parameters:

        cross_optical, dust_radius, dust_sigma = dust_util.interp_lognorm([], [], None)

        _plot_grain_samples('lognorm',
                            cross_optical,
                            dust_radius,
                            dust_sigma,
                            samples[:, box.parameters.index('lognorm_radius')],
                            samples[:, box.parameters.index('lognorm_sigma')],
                            samples[:, box.parameters.index('lognorm_ext')])

    elif 'powerlaw_max' in box.parameters and 'powerlaw_exp' in box.parameters and \
            'powerlaw_ext' in box.parameters:

        cross_optical, dust_max, dust_exp = dust_util.interp_powerlaw([], [], None)

        _plot_grain_samples('powerlaw',
                            cross_optical,
                            dust_max,
                            dust_exp,
                            samples[:, box.parameters.index('powerlaw_max')],
                            samples[:, box.parameters.index('powerlaw_exp')],
                            samples[:, box.parameters.index('powerlaw_ext')])

    elif 'ism_ext' in box.parameters:

        ism_ext = samples[:, box.parameters.index('ism_ext')]

        if 'ism_red' in box.parameters:
            ism_red = samples[:, box.parameters.index('ism_red')]

        else:
            # Use the default ISM reddening (R_V = 3.1) if ism_red was not fitted
            ism_red = np.full(samples.shape[0], 3.1)

        for ext_item, red_item in zip(ism_ext, ism_red):
            _plot_curve(dust_util.ism_extinction(ext_item, red_item, sample_wavel))

    else:
        raise ValueError(
            'The SamplesBox does not contain extinction parameters.')

    print(f'Plotting extinction: {output}...', end='', flush=True)

    # os.path.join gives the same result as before for relative paths and, unlike
    # string concatenation, also handles an absolute output path
    plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.clf()
    plt.close()

    print(' [DONE]')
示例#2
0
    def apply_powerlaw_ext(wavelength: np.ndarray,
                           flux: np.ndarray,
                           r_max_interp: float,
                           exp_interp: float,
                           v_band_ext: float) -> np.ndarray:
        """
        Internal function for applying extinction by dust to a spectrum.

        Parameters
        ----------
        wavelength : np.ndarray
            Wavelengths (um) of the spectrum.
        flux : np.ndarray
            Fluxes (W m-2 um-1) of the spectrum.
        r_max_interp : float
            Logarithm (base 10) of the maximum radius (um) of the power-law size
            distribution.
        exp_interp : float
            Exponent of the power-law size distribution.
        v_band_ext : float
            The extinction (mag) in the V band.

        Returns
        -------
        np.ndarray
            Fluxes (W m-2 um-1) with the extinction applied.

        Raises
        ------
        ValueError
            If the wavelengths or the size distribution parameters fall outside the
            grid of the dust database (``bounds_error=True``).
        """

        database_path = dust_util.check_dust_database()

        with h5py.File(database_path, 'r') as h5_file:
            dust_cross = np.asarray(h5_file['dust/powerlaw/mgsio3/crystalline/cross_section'])
            dust_wavel = np.asarray(h5_file['dust/powerlaw/mgsio3/crystalline/wavelength'])
            dust_r_max = np.asarray(h5_file['dust/powerlaw/mgsio3/crystalline/radius_max'])
            dust_exp = np.asarray(h5_file['dust/powerlaw/mgsio3/crystalline/exponent'])

        # Cross section as function of (wavelength, maximum radius, exponent)
        dust_interp = RegularGridInterpolator((dust_wavel, dust_r_max, dust_exp),
                                              dust_cross,
                                              method='linear',
                                              bounds_error=True)

        read_filt = read_filter.ReadFilter('Generic/Bessell.V')
        filt_trans = read_filt.get_filter()
        filt_wavel = filt_trans[:, 0]
        filt_curve = filt_trans[:, 1]

        # Cross sections at the filter wavelengths for all grid points in one call,
        # shape (n_filter_wavel, n_r_max, n_exp)
        cross_filt = interp1d(dust_wavel,
                              dust_cross,
                              kind='linear',
                              axis=0,
                              bounds_error=True)(filt_wavel)

        # Filter-weighted average of the extinction cross section, shape (n_r_max, n_exp)
        integral1 = np.trapz(filt_curve[:, np.newaxis, np.newaxis]*cross_filt,
                             filt_wavel,
                             axis=0)
        integral2 = np.trapz(filt_curve, filt_wavel)
        cross_phot = integral1/integral2

        # Bilinear interpolation of the V-band cross section on the (r_max, exponent)
        # grid. This replaces scipy.interpolate.interp2d, which is deprecated and removed
        # in SciPy 1.14.
        phot_interp = RegularGridInterpolator((dust_r_max, dust_exp),
                                              cross_phot,
                                              method='linear',
                                              bounds_error=True)

        cross_v_band = phot_interp([[10.**r_max_interp, exp_interp]])[0]

        n_wavel = wavelength.shape[0]

        cross_new = dust_interp(np.column_stack((wavelength,
                                                 np.full(n_wavel, 10.**r_max_interp),
                                                 np.full(n_wavel, exp_interp))))

        # Number of grains for which 2.5*log10(e)*cross_v_band*n_grains equals v_band_ext
        n_grains = v_band_ext / cross_v_band / 2.5 / np.log10(np.exp(1.))

        return flux * np.exp(-cross_new*n_grains)
示例#3
0
def plot_extinction(
    tag: str,
    burnin: Optional[int] = None,
    random: Optional[int] = None,
    wavel_range: Optional[Tuple[float, float]] = None,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    offset: Optional[Tuple[float, float]] = None,
    output: Optional[str] = "extinction.pdf",
) -> None:
    """
    Function to plot random samples of the extinction, either from fitting a size distribution
    of enstatite grains (log-normal: ``lognorm_radius``, ``lognorm_sigma``, and
    ``lognorm_ext``; power-law: ``powerlaw_max``, ``powerlaw_exp``, and ``powerlaw_ext``),
    or from fitting ISM extinction (``ism_ext`` and optionally ``ism_red``).

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    burnin : int, None
        Number of burnin steps to exclude. No steps are excluded if set to ``None``. Only
        used for 3D samples (walkers, steps, parameters), i.e. after running MCMC with
        :func:`~species.analysis.fit_model.FitModel.run_mcmc`.
    random : int, None
        Number of randomly selected samples (drawn with replacement). All samples are used
        if set to ``None``.
    wavel_range : tuple(float, float), None
        Wavelength range (um) for the extinction. The default wavelength range (0.4, 10.) is used
        if set to ``None``.
    xlim : tuple(float, float), None
        Limits of the wavelength axis. The range is set automatically if set to ``None``.
    ylim : tuple(float, float), None
        Limits of the extinction axis. The range is set automatically if set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values are used if set to ``None``.
    output : str, None
        Output filename for the plot. The plot is shown in an
        interface window if the argument is set to ``None``.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If ``burnin`` is not smaller than the number of steps, or if the samples do not
        contain one of the supported sets of extinction parameters.
    """

    if burnin is None:
        burnin = 0

    if wavel_range is None:
        wavel_range = (0.4, 10.0)

    mpl.rcParams["font.serif"] = ["Bitstream Vera Serif"]
    mpl.rcParams["font.family"] = "serif"

    plt.rc("axes", edgecolor="black", linewidth=2.2)

    species_db = database.Database()
    box = species_db.get_samples(tag)

    samples = box.samples

    if samples.ndim == 2 and random is not None:
        ran_index = np.random.randint(samples.shape[0], size=random)
        samples = samples[ran_index, ]

    elif samples.ndim == 3:
        # At least one step must remain after removing the burnin
        if burnin >= samples.shape[1]:
            raise ValueError(
                f"The 'burnin' value is larger than or equal to the number of steps "
                f"({samples.shape[1]}) that are made by the walkers.")

        samples = samples[:, burnin:, :]

        if random is None:
            # Use all post-burnin samples: flatten (walker, step) into a single sample
            # axis so that samples[:, index] below selects one parameter
            samples = samples.reshape(-1, samples.shape[-1])

        else:
            ran_walker = np.random.randint(samples.shape[0], size=random)
            ran_step = np.random.randint(samples.shape[1], size=random)
            samples = samples[ran_walker, ran_step, :]

    plt.figure(1, figsize=(6, 3))
    gridsp = mpl.gridspec.GridSpec(1, 1)
    gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

    ax = plt.subplot(gridsp[0, 0])

    ax.tick_params(
        axis="both",
        which="major",
        colors="black",
        labelcolor="black",
        direction="in",
        width=1,
        length=5,
        labelsize=12,
        top=True,
        bottom=True,
        left=True,
        right=True,
        labelbottom=True,
    )

    ax.tick_params(
        axis="both",
        which="minor",
        colors="black",
        labelcolor="black",
        direction="in",
        width=1,
        length=3,
        labelsize=12,
        top=True,
        bottom=True,
        left=True,
        right=True,
        labelbottom=True,
    )

    ax.set_xlabel("Wavelength (µm)", fontsize=12)
    ax.set_ylabel("Extinction (mag)", fontsize=12)

    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])

    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    if offset is not None:
        ax.get_xaxis().set_label_coords(0.5, offset[0])
        ax.get_yaxis().set_label_coords(offset[1], 0.5)

    else:
        ax.get_xaxis().set_label_coords(0.5, -0.22)
        ax.get_yaxis().set_label_coords(-0.09, 0.5)

    sample_wavel = np.linspace(wavel_range[0], wavel_range[1], 100)

    def _plot_curve(sample_ext: np.ndarray) -> None:
        # Draw one sampled extinction curve; all curves share this thin,
        # semi-transparent style
        ax.plot(sample_wavel,
                sample_ext,
                ls="-",
                lw=0.5,
                color="black",
                alpha=0.5)

    def _plot_grain_samples(dust_type: str,
                            cross_optical,
                            grid_radius: np.ndarray,
                            grid_other: np.ndarray,
                            log_radius: np.ndarray,
                            other: np.ndarray,
                            dust_ext: np.ndarray) -> None:
        # Plot the extinction of every sample of a grain size distribution whose cross
        # sections are stored under "dust/<dust_type>/mgsio3/crystalline" in the dust
        # database. The sampled radius parameter is the log10 of the radius.

        database_path = dust_util.check_dust_database()

        with h5py.File(database_path, "r") as h5_file:
            cross_section = np.asarray(
                h5_file[f"dust/{dust_type}/mgsio3/crystalline/cross_section"])
            wavelength = np.asarray(
                h5_file[f"dust/{dust_type}/mgsio3/crystalline/wavelength"])

        cross_interp = RegularGridInterpolator(
            (wavelength, grid_radius, grid_other), cross_section)

        for log_r_item, other_item, ext_item in zip(log_radius, other, dust_ext):
            radius = 10.0**log_r_item

            # V-band cross section; the optical interpolator takes (other, radius)
            cross_tmp = cross_optical["Generic/Bessell.V"](other_item, radius)

            # Number of grains for which 2.5*log10(e)*cross_tmp*n_grains equals the
            # sampled V-band extinction
            n_grains = ext_item / cross_tmp / 2.5 / np.log10(np.exp(1.0))

            # Evaluate the cross section at all wavelengths with a single call
            points = np.column_stack((sample_wavel,
                                      np.full(sample_wavel.shape, radius),
                                      np.full(sample_wavel.shape, other_item)))

            sample_ext = 2.5 * np.log10(np.exp(1.0)) * cross_interp(points) * n_grains

            _plot_curve(sample_ext)

    if ("lognorm_radius" in box.parameters
            and "lognorm_sigma" in box.parameters
            and "lognorm_ext" in box.parameters):

        cross_optical, dust_radius, dust_sigma = dust_util.interp_lognorm([], [], None)

        _plot_grain_samples("lognorm",
                            cross_optical,
                            dust_radius,
                            dust_sigma,
                            samples[:, box.parameters.index("lognorm_radius")],
                            samples[:, box.parameters.index("lognorm_sigma")],
                            samples[:, box.parameters.index("lognorm_ext")])

    elif ("powerlaw_max" in box.parameters and "powerlaw_exp" in box.parameters
          and "powerlaw_ext" in box.parameters):

        cross_optical, dust_max, dust_exp = dust_util.interp_powerlaw([], [], None)

        _plot_grain_samples("powerlaw",
                            cross_optical,
                            dust_max,
                            dust_exp,
                            samples[:, box.parameters.index("powerlaw_max")],
                            samples[:, box.parameters.index("powerlaw_exp")],
                            samples[:, box.parameters.index("powerlaw_ext")])

    elif "ism_ext" in box.parameters:

        ism_ext = samples[:, box.parameters.index("ism_ext")]

        if "ism_red" in box.parameters:
            ism_red = samples[:, box.parameters.index("ism_red")]

        else:
            # Use the default ISM reddening (R_V = 3.1) if ism_red was not fitted
            ism_red = np.full(samples.shape[0], 3.1)

        for ext_item, red_item in zip(ism_ext, ism_red):
            _plot_curve(dust_util.ism_extinction(ext_item, red_item, sample_wavel))

    else:
        raise ValueError(
            "The SamplesBox does not contain extinction parameters.")

    if output is None:
        print("Plotting extinction...", end="", flush=True)
        # The figure is complete at this point; report before the blocking window opens
        print(" [DONE]")
        plt.show()

    else:
        print(f"Plotting extinction: {output}...", end="", flush=True)
        plt.savefig(output, bbox_inches="tight")
        # Only report success once the file has actually been written
        print(" [DONE]")

    plt.clf()
    plt.close()
示例#4
0
    def apply_lognorm_ext(wavelength: np.ndarray,
                          flux: np.ndarray,
                          radius_interp: float,
                          sigma_interp: float,
                          v_band_ext: float) -> np.ndarray:
        """
        Internal function for applying extinction by dust to a spectrum.

        Parameters
        ----------
        wavelength : np.ndarray
            Wavelengths (um) of the spectrum.
        flux : np.ndarray
            Fluxes (W m-2 um-1) of the spectrum.
        radius_interp : float
            Logarithm (base 10) of the mean geometric radius (um) of the log-normal size
            distribution.
        sigma_interp : float
            Geometric standard deviation (dimensionless) of the log-normal size distribution.
        v_band_ext : float
            The extinction (mag) in the V band.

        Returns
        -------
        np.ndarray
            Fluxes (W m-2 um-1) with the extinction applied.

        Raises
        ------
        ValueError
            If the wavelengths or the size distribution parameters fall outside the
            grid of the dust database (``bounds_error=True``).
        """

        database_path = dust_util.check_dust_database()

        with h5py.File(database_path, 'r') as h5_file:
            dust_cross = np.asarray(h5_file['dust/lognorm/mgsio3/crystalline/cross_section'])
            dust_wavel = np.asarray(h5_file['dust/lognorm/mgsio3/crystalline/wavelength'])
            dust_radius = np.asarray(h5_file['dust/lognorm/mgsio3/crystalline/radius_g'])
            dust_sigma = np.asarray(h5_file['dust/lognorm/mgsio3/crystalline/sigma_g'])

        # Cross section as function of (wavelength, geometric radius, geometric sigma)
        dust_interp = RegularGridInterpolator((dust_wavel, dust_radius, dust_sigma),
                                              dust_cross,
                                              method='linear',
                                              bounds_error=True)

        read_filt = read_filter.ReadFilter('Generic/Bessell.V')
        filt_trans = read_filt.get_filter()
        filt_wavel = filt_trans[:, 0]
        filt_curve = filt_trans[:, 1]

        # Cross sections at the filter wavelengths for all grid points in one call,
        # shape (n_filter_wavel, n_radius, n_sigma)
        cross_filt = interp1d(dust_wavel,
                              dust_cross,
                              kind='linear',
                              axis=0,
                              bounds_error=True)(filt_wavel)

        # Filter-weighted average of the extinction cross section, shape (n_radius, n_sigma)
        integral1 = np.trapz(filt_curve[:, np.newaxis, np.newaxis]*cross_filt,
                             filt_wavel,
                             axis=0)
        integral2 = np.trapz(filt_curve, filt_wavel)
        cross_phot = integral1/integral2

        # Bilinear interpolation of the V-band cross section on the (radius, sigma) grid.
        # This replaces scipy.interpolate.interp2d, which is deprecated and removed in
        # SciPy 1.14.
        phot_interp = RegularGridInterpolator((dust_radius, dust_sigma),
                                              cross_phot,
                                              method='linear',
                                              bounds_error=True)

        cross_v_band = phot_interp([[10.**radius_interp, sigma_interp]])[0]

        n_wavel = wavelength.shape[0]

        cross_new = dust_interp(np.column_stack((wavelength,
                                                 np.full(n_wavel, 10.**radius_interp),
                                                 np.full(n_wavel, sigma_interp))))

        # Number of grains for which 2.5*log10(e)*cross_v_band*n_grains equals v_band_ext
        n_grains = v_band_ext / cross_v_band / 2.5 / np.log10(np.exp(1.))

        return flux * np.exp(-cross_new*n_grains)