예제 #1
0
def plot_walkers(tag: str,
                 nsteps: Optional[int] = None,
                 offset: Optional[Tuple[float, float]] = None,
                 output: str = 'walkers.pdf') -> None:
    """
    Function to plot the step history of the walkers.

    One panel is drawn per parameter (stacked vertically) and every walker is
    plotted as a thin, semi-transparent line in each panel.

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    nsteps : int, None
        Number of steps that are plotted (used as upper limit of the x-axis). All steps
        are plotted if set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values are used if set to ``None``.
    output : str
        Output filename. A relative path is resolved against the current working
        directory, an absolute path is used as is.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If the stored samples have only 2 dimensions, in which case there is no
        walker axis to plot.
    """

    print(f'Plotting walkers: {output}...', end='', flush=True)

    mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
    mpl.rcParams['font.family'] = 'serif'

    plt.rc('axes', edgecolor='black', linewidth=2.2)

    species_db = database.Database()
    box = species_db.get_samples(tag)

    samples = box.samples
    labels = plot_util.update_labels(box.parameters)

    if samples.ndim == 2:
        raise ValueError(
            f'The samples of \'{tag}\' have only 2 dimensions whereas 3 are required '
            f'for plotting the walkers. The plot_walkers function can only be '
            f'used after running the MCMC with run_mcmc and not after running '
            f'MultiNest with run_multinest.')

    ndim = samples.shape[-1]

    # The label positions do not depend on the panel, so resolve them once
    if offset is not None:
        xlabel_pos = (0.5, offset[0])
        ylabel_pos = (offset[1], 0.5)
    else:
        xlabel_pos = (0.5, -0.22)
        ylabel_pos = (-0.09, 0.5)

    plt.figure(1, figsize=(6, ndim * 1.5))
    gridsp = mpl.gridspec.GridSpec(ndim, 1)
    gridsp.update(wspace=0, hspace=0.1, left=0, right=1, bottom=0, top=1)

    for i in range(ndim):
        ax = plt.subplot(gridsp[i, 0])

        # Only the bottom panel shows the tick labels and title of the x-axis
        is_last = i == ndim - 1

        # Major and minor ticks share the same style and differ only in length
        for which, length in (('major', 5), ('minor', 3)):
            ax.tick_params(axis='both',
                           which=which,
                           colors='black',
                           labelcolor='black',
                           direction='in',
                           width=1,
                           length=length,
                           labelsize=12,
                           top=True,
                           bottom=True,
                           left=True,
                           right=True,
                           labelbottom=is_last)

        ax.set_xlabel('Step number' if is_last else '', fontsize=10)
        ax.set_ylabel(labels[i], fontsize=10)

        ax.get_xaxis().set_label_coords(*xlabel_pos)
        ax.get_yaxis().set_label_coords(*ylabel_pos)

        if nsteps is not None:
            ax.set_xlim(0, nsteps)

        # The first axis of the samples enumerates the walkers
        for walker in samples:
            ax.plot(walker[:, i], ls='-', lw=0.5, color='black', alpha=0.5)

    # os.path.join keeps an absolute ``output`` intact, whereas plain string
    # concatenation with the working directory would corrupt it
    plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.clf()
    plt.close()

    print(' [DONE]')
예제 #2
0
def plot_posterior(tag: str,
                   burnin: Optional[int] = None,
                   title: Optional[str] = None,
                   offset: Optional[Tuple[float, float]] = None,
                   title_fmt: Union[str, List[str]] = '.2f',
                   limits: Optional[List[Tuple[float, float]]] = None,
                   max_posterior: bool = False,
                   inc_luminosity: bool = False,
                   inc_mass: bool = False,
                   output: str = 'posterior.pdf') -> None:
    """
    Function to plot the posterior distribution.

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    burnin : int, None
        Number of burnin steps to exclude. All samples are used if set to ``None``.
    title : str, None
        Plot title. No title is shown if set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values are used if set to ``None``.
    title_fmt : str, list(str)
        Format of the titles above the 1D distributions. Either a single string, which will be used
        for all parameters, or a list with the title format for each parameter separately (in the
        order as shown in the corner plot). A list requires one item per parameter, including
        derived quantities (luminosity, mass) and parameters that are fixed; the items of fixed
        parameters are ignored since these parameters are not plotted.
    limits : list(tuple(float, float), ), None
        Axis limits of all parameters. Automatically set if set to ``None``.
    max_posterior : bool
        Plot the position of the sample with the maximum posterior probability. A warning is
        issued and nothing is marked if the samples do not provide that sample. No markers are
        drawn for the derived quantities (luminosity, mass).
    inc_luminosity : bool
        Include the log10 of the luminosity in the posterior plot as calculated from the
        effective temperature and radius.
    inc_mass : bool
        Include the mass in the posterior plot as calculated from the surface gravity and radius.
    output : str
        Output filename. A relative path is resolved against the current working
        directory, an absolute path is used as is.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If ``title_fmt`` is a list of which the length differs from the number of
        parameters (including the derived quantities).
    """

    mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
    mpl.rcParams['font.family'] = 'serif'

    plt.rc('axes', edgecolor='black', linewidth=2.2)

    if burnin is None:
        burnin = 0

    species_db = database.Database()
    box = species_db.get_samples(tag, burnin=burnin)

    print('Median sample:')
    for key, value in box.median_sample.items():
        print(f'   - {key} = {value:.2f}')

    samples = box.samples
    ndim = samples.shape[-1]

    # Values of the maximum posterior sample, in the order of the parameters
    par_val = None

    if box.prob_sample is not None:
        par_val = list(box.prob_sample.values())

        print('Maximum posterior sample:')
        for key, value in box.prob_sample.items():
            print(f'   - {key} = {value:.2f}')

    if max_posterior and par_val is None:
        # Without this guard the plotting loop would fail on the missing values
        warnings.warn(
            'The samples do not contain the maximum posterior sample so '
            '\'max_posterior=True\' is ignored.'
        )
        max_posterior = False

    print(f'Plotting the posterior: {output}...', end='', flush=True)

    if inc_luminosity:
        # Collect the (teff, radius) parameter names of all atmospheric components
        components = []

        if 'teff' in box.parameters and 'radius' in box.parameters:
            components.append(('teff', 'radius'))

        elif 'teff_0' in box.parameters and 'radius_0' in box.parameters:
            for i in range(100):
                if f'teff_{i}' in box.parameters and f'radius_{i}' in box.parameters:
                    components.append((f'teff_{i}', f'radius_{i}'))
                else:
                    break

        if components:
            luminosity = 0.

            for teff_key, radius_key in components:
                # Indexing with a one-item list keeps the last axis (length 1)
                # such that the result can be appended as extra column
                teff_index = [box.parameters.index(teff_key)]
                radius_index = [box.parameters.index(radius_key)]

                luminosity = luminosity + 4. * np.pi * \
                    (samples[..., radius_index]*constants.R_JUP)**2 * \
                    constants.SIGMA_SB * samples[..., teff_index]**4. / constants.L_SUN

            samples = np.append(samples, np.log10(luminosity), axis=-1)
            box.parameters.append('luminosity')
            ndim += 1

        else:
            warnings.warn(
                'Samples with the effective temperature and radius are required for '
                '\'inc_luminosity=True\'.'
            )

    if inc_mass:
        if 'logg' in box.parameters and 'radius' in box.parameters:
            logg_index = [box.parameters.index('logg')]
            radius_index = [box.parameters.index('radius')]

            mass_samples = read_util.get_mass(samples[..., logg_index],
                                              samples[..., radius_index])

            samples = np.append(samples, mass_samples, axis=-1)
            box.parameters.append('mass')
            ndim += 1

        else:
            warnings.warn(
                'Samples with the log(g) and radius are required for \'inc_mass=True\'.'
            )

    if isinstance(title_fmt, list) and len(title_fmt) != ndim:
        raise ValueError(
            f'The number of items in the list of \'title_fmt\' ({len(title_fmt)}) is '
            f'not equal to the number of dimensions of the samples ({ndim}).')

    labels = plot_util.update_labels(box.parameters)

    # Check if parameter values were fixed

    index_sel = []
    index_del = []

    # Use only last axis for parameter dimensions
    for i in range(ndim):
        if np.amin(samples[..., i]) == np.amax(samples[..., i]):
            index_del.append(i)
        else:
            index_sel.append(i)

    samples = samples[..., index_sel]

    # Delete from the back such that the remaining indices stay valid
    for i in reversed(index_del):
        del labels[i]

    if par_val is not None:
        # The derived quantities that were appended have no maximum posterior
        # value so they are padded with None, after which the same selection
        # as for the samples is applied to keep the values aligned with the axes
        par_val += [None] * (ndim - len(par_val))
        par_val = [par_val[i] for i in index_sel]

    # Apply the same selection to the title formats to keep them aligned
    if isinstance(title_fmt, list):
        fmt_specs = [title_fmt[i] for i in index_sel]
    else:
        fmt_specs = [title_fmt] * len(index_sel)

    ndim -= len(index_del)

    samples = samples.reshape((-1, ndim))

    hist_titles = []

    for i, item in enumerate(labels):
        unit_start = item.find('(')

        if unit_start == -1:
            param_label = item
            unit_label = None

        else:
            param_label = item[:unit_start]
            # Remove parenthesis from the units
            unit_label = item[unit_start + 1:-1]

        q_16, q_50, q_84 = corner.quantile(samples[:, i], [0.16, 0.5, 0.84])
        q_minus, q_plus = q_50 - q_16, q_84 - q_50

        fmt = '{{0:{0}}}'.format(fmt_specs[i]).format

        best_fit = r'${{{0}}}_{{-{1}}}^{{+{2}}}$'
        best_fit = best_fit.format(fmt(q_50), fmt(q_minus), fmt(q_plus))

        if unit_label is None:
            hist_title = f'{param_label} = {best_fit}'

        else:
            hist_title = f'{param_label} = {best_fit} {unit_label}'

        hist_titles.append(hist_title)

    fig = corner.corner(samples,
                        quantiles=[0.16, 0.5, 0.84],
                        labels=labels,
                        label_kwargs={'fontsize': 13},
                        titles=hist_titles,
                        show_titles=True,
                        title_fmt=None,
                        title_kwargs={'fontsize': 12})

    axes = np.array(fig.axes).reshape((ndim, ndim))

    # The label positions do not depend on the panel, so resolve them once
    if offset is not None:
        xlabel_pos = (0.5, offset[0])
        ylabel_pos = (offset[1], 0.5)
    else:
        xlabel_pos = (0.5, -0.26)
        ylabel_pos = (-0.27, 0.5)

    for i in range(ndim):
        # Only the diagonal and the lower triangle of the grid are styled
        for j in range(i + 1):
            ax = axes[i, j]

            ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))
            ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))

            labelleft = j == 0 and i != 0
            labelbottom = i == ndim - 1

            # Major and minor ticks share the same style and differ only in length
            for which, length in (('major', 5), ('minor', 3)):
                ax.tick_params(axis='both',
                               which=which,
                               colors='black',
                               labelcolor='black',
                               direction='in',
                               width=1,
                               length=length,
                               labelsize=12,
                               top=True,
                               bottom=True,
                               left=True,
                               right=True,
                               labelleft=labelleft,
                               labelbottom=labelbottom,
                               labelright=False,
                               labeltop=False)

            if limits is not None:
                ax.set_xlim(limits[j])

            if max_posterior and par_val[j] is not None:
                ax.axvline(par_val[j], color='tomato')

            if i > j:
                if max_posterior:
                    if par_val[i] is not None:
                        ax.axhline(par_val[i], color='tomato')

                    if par_val[i] is not None and par_val[j] is not None:
                        ax.plot(par_val[j], par_val[i], 's', color='tomato')

                if limits is not None:
                    ax.set_ylim(limits[i])

            ax.get_xaxis().set_label_coords(*xlabel_pos)
            ax.get_yaxis().set_label_coords(*ylabel_pos)

    if title:
        fig.suptitle(title, y=1.02, fontsize=16)

    # os.path.join keeps an absolute ``output`` intact, whereas plain string
    # concatenation with the working directory would corrupt it
    plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.clf()
    plt.close()

    print(' [DONE]')
예제 #3
0
def plot_walkers(
    tag: str,
    nsteps: Optional[int] = None,
    offset: Optional[Tuple[float, float]] = None,
    output: Optional[str] = "walkers.pdf",
) -> None:
    """
    Plot the evolution of each walker as function of the step number,
    with one vertically stacked panel per parameter.

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    nsteps : int, None
        Upper limit of the step axis. The axis range is not
        adjusted if the argument is set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values
        are used if the argument is set to ``None``.
    output : str, None
        Output filename for the plot. The plot is shown in an
        interface window if the argument is set to ``None``.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If the samples have only 2 dimensions, so there is
        no separate walker axis available.
    """

    if output is not None:
        print(f"Plotting walkers: {output}...", end="", flush=True)
    else:
        print("Plotting walkers...", end="", flush=True)

    mpl.rcParams["font.serif"] = ["Bitstream Vera Serif"]
    mpl.rcParams["font.family"] = "serif"

    plt.rc("axes", edgecolor="black", linewidth=2.2)

    sample_box = database.Database().get_samples(tag)

    samples = sample_box.samples
    labels = plot_util.update_labels(sample_box.parameters)

    if samples.ndim == 2:
        raise ValueError(
            f"The samples of '{tag}' have only 2 dimensions whereas 3 are required "
            f"for plotting the walkers. The plot_walkers function can only be "
            f"used after running the MCMC with run_mcmc and not after running "
            f"run_ultranest or run_multinest.")

    n_param = samples.shape[-1]
    n_walkers = samples.shape[0]

    # Positions of the axis labels are the same for all panels
    if offset is None:
        x_coords = (0.5, -0.22)
        y_coords = (-0.09, 0.5)
    else:
        x_coords = (0.5, offset[0])
        y_coords = (offset[1], 0.5)

    # Style shared by the major and minor ticks
    tick_style = {
        "axis": "both",
        "colors": "black",
        "labelcolor": "black",
        "direction": "in",
        "width": 1,
        "labelsize": 12,
        "top": True,
        "bottom": True,
        "left": True,
        "right": True,
    }

    plt.figure(1, figsize=(6, n_param * 1.5))
    gridsp = mpl.gridspec.GridSpec(n_param, 1)
    gridsp.update(wspace=0, hspace=0.1, left=0, right=1, bottom=0, top=1)

    for param_idx in range(n_param):
        ax = plt.subplot(gridsp[param_idx, 0])

        # Tick labels and title of the step axis only on the bottom panel
        bottom_panel = param_idx == n_param - 1

        for tick_type, tick_length in (("major", 5), ("minor", 3)):
            ax.tick_params(
                which=tick_type,
                length=tick_length,
                labelbottom=bottom_panel,
                **tick_style,
            )

        if bottom_panel:
            ax.set_xlabel("Step number", fontsize=10)
        else:
            ax.set_xlabel("", fontsize=10)

        ax.set_ylabel(labels[param_idx], fontsize=10)

        ax.get_xaxis().set_label_coords(*x_coords)
        ax.get_yaxis().set_label_coords(*y_coords)

        if nsteps is not None:
            ax.set_xlim(0, nsteps)

        # One thin line per walker
        for walker_idx in range(n_walkers):
            ax.plot(
                samples[walker_idx, :, param_idx],
                ls="-",
                lw=0.5,
                color="black",
                alpha=0.5,
            )

    print(" [DONE]")

    if output is not None:
        plt.savefig(output, bbox_inches="tight")
    else:
        plt.show()

    plt.clf()
    plt.close()
예제 #4
0
def plot_posterior(
    tag: str,
    burnin: Optional[int] = None,
    title: Optional[str] = None,
    offset: Optional[Tuple[float, float]] = None,
    title_fmt: Union[str, List[str]] = ".2f",
    limits: Optional[List[Tuple[float, float]]] = None,
    max_prob: bool = False,
    vmr: bool = False,
    inc_luminosity: bool = False,
    inc_mass: bool = False,
    inc_pt_param: bool = False,
    inc_loglike: bool = False,
    output: Optional[str] = "posterior.pdf",
) -> None:
    """
    Function to plot the posterior distribution of the fitted parameters.

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    burnin : int, None
        Number of burnin steps to exclude. All samples are used if set to ``None``.
    title : str, None
        Plot title. No title is shown if set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values are used if set to ``None``.
    title_fmt : str, list(str)
        Format of the titles above the 1D distributions. Either a single string, which will be used
        for all parameters, or a list with the title format for each parameter separately (in the
        order as shown in the corner plot).
    limits : list(tuple(float, float), ), None
        Axis limits of all parameters. Automatically set if set to ``None``.
    max_prob : bool
        Plot the position of the sample with the maximum posterior probability.
    vmr : bool
        Plot the volume mixing ratios (i.e. number fractions) instead of the mass fractions of the
        retrieved species with :class:`~species.analysis.retrieval.AtmosphericRetrieval`.
    inc_luminosity : bool
        Include the log10 of the luminosity in the posterior plot as calculated from the
        effective temperature and radius.
    inc_mass : bool
        Include the mass in the posterior plot as calculated from the surface gravity and radius.
    inc_pt_param : bool
        Include the parameters of the pressure-temperature profile. Only used if the ``tag``
        contains samples obtained with :class:`~species.analysis.retrieval.AtmosphericRetrieval`.
    inc_loglike : bool
        Include the log10 of the likelihood as additional parameter in the corner plot.
    output : str
        Output filename for the plot. The plot is shown in an
        interface window if the argument is set to ``None``.

    Returns
    -------
    NoneType
        None
    """

    mpl.rcParams["font.serif"] = ["Bitstream Vera Serif"]
    mpl.rcParams["font.family"] = "serif"

    plt.rc("axes", edgecolor="black", linewidth=2.2)

    if burnin is None:
        burnin = 0

    species_db = database.Database()

    box = species_db.get_samples(tag, burnin=burnin)
    samples = box.samples

    # index_sel = [0, 1, 8, 9, 14]
    # samples = samples[:, index_sel]
    #
    # for i in range(13, 9, -1):
    #     del box.parameters[i]
    #
    # del box.parameters[2]
    # del box.parameters[2]
    # del box.parameters[2]
    # del box.parameters[2]
    # del box.parameters[2]
    # del box.parameters[2]

    ndim = len(box.parameters)

    if not inc_pt_param and box.spectrum == "petitradtrans":
        pt_param = ["tint", "t1", "t2", "t3", "alpha", "log_delta"]

        index_del = []
        item_del = []

        for i in range(100):
            pt_item = f"t{i}"

            if pt_item in box.parameters:
                param_index = np.argwhere(
                    np.array(box.parameters) == pt_item)[0]
                index_del.append(param_index)
                item_del.append(pt_item)

            else:
                break

        for item in pt_param:
            if item in box.parameters and item not in item_del:
                param_index = np.argwhere(np.array(box.parameters) == item)[0]
                index_del.append(param_index)
                item_del.append(item)

        samples = np.delete(samples, index_del, axis=1)
        ndim -= len(index_del)

        for item in item_del:
            box.parameters.remove(item)

    if box.spectrum == "petitradtrans" and box.attributes[
            "chemistry"] == "free":
        box.parameters.append("c_h_ratio")
        box.parameters.append("o_h_ratio")
        box.parameters.append("c_o_ratio")

        ndim += 3

        abund_index = {}
        for i, item in enumerate(box.parameters):
            if item == "CH4":
                abund_index["CH4"] = i

            elif item == "CO":
                abund_index["CO"] = i

            elif item == "CO_all_iso":
                abund_index["CO_all_iso"] = i

            elif item == "CO_all_iso_HITEMP":
                abund_index["CO_all_iso_HITEMP"] = i

            elif item == "CO2":
                abund_index["CO2"] = i

            elif item == "FeH":
                abund_index["FeH"] = i

            elif item == "H2O":
                abund_index["H2O"] = i

            elif item == "H2O_HITEMP":
                abund_index["H2O_HITEMP"] = i

            elif item == "H2S":
                abund_index["H2S"] = i

            elif item == "Na":
                abund_index["Na"] = i

            elif item == "NH3":
                abund_index["NH3"] = i

            elif item == "K":
                abund_index["K"] = i

            elif item == "PH3":
                abund_index["PH3"] = i

            elif item == "TiO":
                abund_index["TiO"] = i

            elif item == "TiO_all_Exomol":
                abund_index["TiO_all_Exomol"] = i

            elif item == "VO":
                abund_index["VO"] = i

            elif item == "VO_Plez":
                abund_index["VO_Plez"] = i

        c_h_ratio = np.zeros(samples.shape[0])
        o_h_ratio = np.zeros(samples.shape[0])
        c_o_ratio = np.zeros(samples.shape[0])

        for i, item in enumerate(samples):
            abund = {}

            if "CH4" in box.parameters:
                abund["CH4"] = item[abund_index["CH4"]]

            if "CO" in box.parameters:
                abund["CO"] = item[abund_index["CO"]]

            if "CO_all_iso" in box.parameters:
                abund["CO_all_iso"] = item[abund_index["CO"]]

            if "CO_all_iso_HITEMP" in box.parameters:
                abund["CO_all_iso_HITEMP"] = item[
                    abund_index["CO_all_iso_HITEMP"]]

            if "CO2" in box.parameters:
                abund["CO2"] = item[abund_index["CO2"]]

            if "FeH" in box.parameters:
                abund["FeH"] = item[abund_index["FeH"]]

            if "H2O" in box.parameters:
                abund["H2O"] = item[abund_index["H2O"]]

            if "H2O_HITEMP" in box.parameters:
                abund["H2O_HITEMP"] = item[abund_index["H2O_HITEMP"]]

            if "H2S" in box.parameters:
                abund["H2S"] = item[abund_index["H2S"]]

            if "Na" in box.parameters:
                abund["Na"] = item[abund_index["Na"]]

            if "K" in box.parameters:
                abund["K"] = item[abund_index["K"]]

            if "NH3" in box.parameters:
                abund["NH3"] = item[abund_index["NH3"]]

            if "PH3" in box.parameters:
                abund["PH3"] = item[abund_index["PH3"]]

            if "TiO" in box.parameters:
                abund["TiO"] = item[abund_index["TiO"]]

            if "TiO_all_Exomol" in box.parameters:
                abund["TiO_all_Exomol"] = item[abund_index["TiO_all_Exomol"]]

            if "VO" in box.parameters:
                abund["VO"] = item[abund_index["VO"]]

            if "VO_Plez" in box.parameters:
                abund["VO_Plez"] = item[abund_index["VO_Plez"]]

            c_h_ratio[i], o_h_ratio[i], c_o_ratio[
                i] = retrieval_util.calc_metal_ratio(abund)

    if (vmr and box.spectrum == "petitradtrans"
            and box.attributes["chemistry"] == "free"):
        print("Changing mass fractions to number fractions...",
              end="",
              flush=True)

        # Get all available line species
        line_species = retrieval_util.get_line_species()

        # Get the atomic and molecular masses
        masses = retrieval_util.atomic_masses()

        # Create array for the updated samples
        updated_samples = np.zeros(samples.shape)

        for i, samples_item in enumerate(samples):
            # Initiate a dictionary for the log10 mass fraction of the metals
            log_x_abund = {}

            for param_item in box.parameters:
                if param_item in line_species:
                    # Get the index of the parameter
                    param_index = box.parameters.index(param_item)

                    # Store log10 mass fraction in the dictionary
                    log_x_abund[param_item] = samples_item[param_index]

            # Create a dictionary with all mass fractions, including H2 and He
            x_abund = retrieval_util.mass_fractions(log_x_abund)

            # Calculate the mean molecular weight from the input mass fractions
            mmw = retrieval_util.mean_molecular_weight(x_abund)

            for param_item in box.parameters:
                if param_item in line_species:
                    # Get the index of the parameter
                    param_index = box.parameters.index(param_item)

                    # Overwrite the sample with the log10 number fraction
                    samples_item[param_index] = np.log10(
                        10.0**samples_item[param_index] * mmw /
                        masses[param_item])

            # Store the updated sample to the array
            updated_samples[i, ] = samples_item

        # Overwrite the samples in the SamplesBox
        box.samples = updated_samples

        print(" [DONE]")

    print("Median sample:")
    for key, value in box.median_sample.items():
        print(f"   - {key} = {value:.2e}")

    if "gauss_mean" in box.parameters:
        param_index = np.argwhere(np.array(box.parameters) == "gauss_mean")[0]
        samples[:, param_index] *= 1e3  # (um) -> (nm)

    if "gauss_sigma" in box.parameters:
        param_index = np.argwhere(np.array(box.parameters) == "gauss_sigma")[0]
        samples[:, param_index] *= 1e3  # (um) -> (nm)

    if box.prob_sample is not None:
        par_val = tuple(box.prob_sample.values())

        print("Maximum posterior sample:")
        for key, value in box.prob_sample.items():
            print(f"   - {key} = {value:.2e}")

    for item in box.parameters:
        if item[0:11] == "wavelength_":
            param_index = box.parameters.index(item)

            # (um) -> (nm)
            box.samples[:, param_index] *= 1e3

    if output is None:
        print("Plotting the posterior...", end="", flush=True)
    else:
        print(f"Plotting the posterior: {output}...", end="", flush=True)

    if "H2O" in box.parameters or "H2O_HITEMP" in box.parameters:
        samples = np.column_stack((samples, c_h_ratio, o_h_ratio, c_o_ratio))

    if inc_luminosity:
        if "teff" in box.parameters and "radius" in box.parameters:
            teff_index = np.argwhere(np.array(box.parameters) == "teff")[0]
            radius_index = np.argwhere(np.array(box.parameters) == "radius")[0]

            lum_planet = (4.0 * np.pi *
                          (samples[..., radius_index] * constants.R_JUP)**2 *
                          constants.SIGMA_SB * samples[..., teff_index]**4.0 /
                          constants.L_SUN)

            if "disk_teff" in box.parameters and "disk_radius" in box.parameters:
                teff_index = np.argwhere(
                    np.array(box.parameters) == "disk_teff")[0]
                radius_index = np.argwhere(
                    np.array(box.parameters) == "disk_radius")[0]

                lum_disk = (4.0 * np.pi *
                            (samples[..., radius_index] * constants.R_JUP)**2 *
                            constants.SIGMA_SB *
                            samples[..., teff_index]**4.0 / constants.L_SUN)

                samples = np.append(samples,
                                    np.log10(lum_planet + lum_disk),
                                    axis=-1)
                box.parameters.append("luminosity")
                ndim += 1

                samples = np.append(samples, lum_disk / lum_planet, axis=-1)
                box.parameters.append("luminosity_disk_planet")
                ndim += 1

            else:
                samples = np.append(samples, np.log10(lum_planet), axis=-1)
                box.parameters.append("luminosity")
                ndim += 1

        elif "teff_0" in box.parameters and "radius_0" in box.parameters:
            luminosity = 0.0

            for i in range(100):
                teff_index = np.argwhere(
                    np.array(box.parameters) == f"teff_{i}")
                radius_index = np.argwhere(
                    np.array(box.parameters) == f"radius_{i}")

                if len(teff_index) > 0 and len(radius_index) > 0:
                    luminosity += (
                        4.0 * np.pi *
                        (samples[..., radius_index[0]] * constants.R_JUP)**2 *
                        constants.SIGMA_SB * samples[..., teff_index[0]]**4.0 /
                        constants.L_SUN)

                else:
                    break

            samples = np.append(samples, np.log10(luminosity), axis=-1)
            box.parameters.append("luminosity")
            ndim += 1

            # teff_index = np.argwhere(np.array(box.parameters) == 'teff_0')
            # radius_index = np.argwhere(np.array(box.parameters) == 'radius_0')
            #
            # luminosity_0 = 4. * np.pi * (samples[..., radius_index[0]]*constants.R_JUP)**2 \
            #     * constants.SIGMA_SB * samples[..., teff_index[0]]**4. / constants.L_SUN
            #
            # samples = np.append(samples, np.log10(luminosity_0), axis=-1)
            # box.parameters.append('luminosity_0')
            # ndim += 1
            #
            # teff_index = np.argwhere(np.array(box.parameters) == 'teff_1')
            # radius_index = np.argwhere(np.array(box.parameters) == 'radius_1')
            #
            # luminosity_1 = 4. * np.pi * (samples[..., radius_index[0]]*constants.R_JUP)**2 \
            #     * constants.SIGMA_SB * samples[..., teff_index[0]]**4. / constants.L_SUN
            #
            # samples = np.append(samples, np.log10(luminosity_1), axis=-1)
            # box.parameters.append('luminosity_1')
            # ndim += 1
            #
            # teff_index_0 = np.argwhere(np.array(box.parameters) == 'teff_0')
            # radius_index_0 = np.argwhere(np.array(box.parameters) == 'radius_0')
            #
            # teff_index_1 = np.argwhere(np.array(box.parameters) == 'teff_1')
            # radius_index_1 = np.argwhere(np.array(box.parameters) == 'radius_1')
            #
            # luminosity_0 = 4. * np.pi * (samples[..., radius_index_0[0]]*constants.R_JUP)**2 \
            #     * constants.SIGMA_SB * samples[..., teff_index_0[0]]**4. / constants.L_SUN
            #
            # luminosity_1 = 4. * np.pi * (samples[..., radius_index_1[0]]*constants.R_JUP)**2 \
            #     * constants.SIGMA_SB * samples[..., teff_index_1[0]]**4. / constants.L_SUN
            #
            # samples = np.append(samples, np.log10(luminosity_0/luminosity_1), axis=-1)
            # box.parameters.append('luminosity_ratio')
            # ndim += 1

            # r_tmp = samples[..., radius_index_0[0]]*constants.R_JUP
            # lum_diff = (luminosity_1*constants.L_SUN-luminosity_0*constants.L_SUN)
            #
            # m_mdot = (3600.*24.*365.25)*lum_diff*r_tmp/constants.GRAVITY/constants.M_JUP**2
            #
            # samples = np.append(samples, m_mdot, axis=-1)
            # box.parameters.append('m_mdot')
            # ndim += 1

    if inc_mass:
        if "logg" in box.parameters and "radius" in box.parameters:
            logg_index = np.argwhere(np.array(box.parameters) == "logg")[0]
            radius_index = np.argwhere(np.array(box.parameters) == "radius")[0]

            mass_samples = read_util.get_mass(samples[..., logg_index],
                                              samples[..., radius_index])

            samples = np.append(samples, mass_samples, axis=-1)
            box.parameters.append("mass")
            ndim += 1

        else:
            warnings.warn(
                "Samples with the log(g) and radius are required for 'inc_mass=True'."
            )

    if inc_loglike:
        # Get ln(L) of the samples
        ln_prob = box.ln_prob[..., np.newaxis]

        # Normalized by the maximum ln(L)
        ln_prob -= np.amax(ln_prob)

        # Convert ln(L) to log10(L)
        log_prob = ln_prob * np.exp(1.0)

        # Convert log10(L) to L
        prob = 10.0**log_prob

        # Normalize to an integrated probability of 1
        prob /= np.sum(prob)

        samples = np.append(samples, np.log10(prob), axis=-1)
        box.parameters.append("log_prob")
        ndim += 1

    labels = plot_util.update_labels(box.parameters)

    # Check if parameter values were fixed

    index_sel = []
    index_del = []

    for i in range(ndim):
        if np.amin(samples[:, i]) == np.amax(samples[:, i]):
            index_del.append(i)
        else:
            index_sel.append(i)

    samples = samples[:, index_sel]

    for i in range(len(index_del) - 1, -1, -1):
        del labels[index_del[i]]

    ndim -= len(index_del)

    samples = samples.reshape((-1, ndim))

    if isinstance(title_fmt, list) and len(title_fmt) != ndim:
        raise ValueError(
            f"The number of items in the list of 'title_fmt' ({len(title_fmt)}) is "
            f"not equal to the number of dimensions of the samples ({ndim}).")

    hist_titles = []

    for i, item in enumerate(labels):
        unit_start = item.find("(")

        if unit_start == -1:
            param_label = item
            unit_label = None

        else:
            param_label = item[:unit_start]
            # Remove parenthesis from the units
            unit_label = item[unit_start + 1:-1]

        q_16, q_50, q_84 = corner.quantile(samples[:, i], [0.16, 0.5, 0.84])
        q_minus, q_plus = q_50 - q_16, q_84 - q_50

        if isinstance(title_fmt, str):
            fmt = "{{0:{0}}}".format(title_fmt).format

        elif isinstance(title_fmt, list):
            fmt = "{{0:{0}}}".format(title_fmt[i]).format

        best_fit = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
        best_fit = best_fit.format(fmt(q_50), fmt(q_minus), fmt(q_plus))

        if unit_label is None:
            hist_title = f"{param_label} = {best_fit}"

        else:
            hist_title = f"{param_label} = {best_fit} {unit_label}"

        hist_titles.append(hist_title)

    fig = corner.corner(
        samples,
        quantiles=[0.16, 0.5, 0.84],
        labels=labels,
        label_kwargs={"fontsize": 13},
        titles=hist_titles,
        show_titles=True,
        title_fmt=None,
        title_kwargs={"fontsize": 12},
    )

    axes = np.array(fig.axes).reshape((ndim, ndim))

    for i in range(ndim):
        for j in range(ndim):
            if i >= j:
                ax = axes[i, j]

                ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))
                ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))

                labelleft = j == 0 and i != 0
                labelbottom = i == ndim - 1

                ax.tick_params(
                    axis="both",
                    which="major",
                    colors="black",
                    labelcolor="black",
                    direction="in",
                    width=1,
                    length=5,
                    labelsize=12,
                    top=True,
                    bottom=True,
                    left=True,
                    right=True,
                    labelleft=labelleft,
                    labelbottom=labelbottom,
                    labelright=False,
                    labeltop=False,
                )

                ax.tick_params(
                    axis="both",
                    which="minor",
                    colors="black",
                    labelcolor="black",
                    direction="in",
                    width=1,
                    length=3,
                    labelsize=12,
                    top=True,
                    bottom=True,
                    left=True,
                    right=True,
                    labelleft=labelleft,
                    labelbottom=labelbottom,
                    labelright=False,
                    labeltop=False,
                )

                if limits is not None:
                    ax.set_xlim(limits[j])

                if max_prob:
                    ax.axvline(par_val[j], color="tomato")

                if i > j:
                    if max_prob:
                        ax.axhline(par_val[i], color="tomato")
                        ax.plot(par_val[j], par_val[i], "s", color="tomato")

                    if limits is not None:
                        ax.set_ylim(limits[i])

                if offset is not None:
                    ax.get_xaxis().set_label_coords(0.5, offset[0])
                    ax.get_yaxis().set_label_coords(offset[1], 0.5)

                else:
                    ax.get_xaxis().set_label_coords(0.5, -0.26)
                    ax.get_yaxis().set_label_coords(-0.27, 0.5)

    if title:
        fig.suptitle(title, y=1.02, fontsize=16)

    print(" [DONE]")

    if output is None:
        plt.show()
    else:
        plt.savefig(output, bbox_inches="tight")

    plt.clf()
    plt.close()
예제 #5
0
def plot_walkers(tag, output, nsteps=None, offset=None):
    """
    Function to plot the step history of the walkers.

    One panel is drawn per parameter and all panels are stacked vertically.
    Only the bottom panel shows the tick labels and the label of the x-axis.

    Parameters
    ----------
    tag : str
        Database tag with the MCMC samples.
    output : str
        Output filename. A relative path is resolved against the current
        working directory, an absolute path is used as given.
    nsteps : int, None
        Upper limit of the x-axis (step number). The limit is chosen by
        matplotlib if set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values are used if set
        to ``None``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the samples do not have 3 dimensions. The array is indexed
        as (walker, step, parameter).
    """

    sys.stdout.write('Plotting walkers: ' + output + '...')
    sys.stdout.flush()

    species_db = database.Database()
    box = species_db.get_samples(tag)

    samples = box.samples
    labels = plot_util.update_labels(box.parameters)

    # The per-walker step history is only available in a 3D array
    if samples.ndim != 3:
        raise ValueError(
            f'The samples of \'{tag}\' have {samples.ndim} dimensions whereas '
            f'3 dimensions (walker, step, parameter) are required for '
            f'plotting the walkers.')

    ndim = samples.shape[-1]

    # One grid row per parameter. The height scales with the number of
    # parameters such that 4 parameters give the previous 6x5 figure.
    plt.figure(1, figsize=(6, 1.25 * ndim))
    gridsp = mpl.gridspec.GridSpec(ndim, 1)
    gridsp.update(wspace=0, hspace=0.1, left=0, right=1, bottom=0, top=1)

    for i in range(ndim):
        ax = plt.subplot(gridsp[i, 0])

        # Only the bottom panel gets tick labels and a label on the x-axis
        is_bottom = i == ndim - 1

        ax.grid(True,
                linestyle=':',
                linewidth=0.7,
                color='silver',
                dashes=(1, 4))

        # Major and minor ticks only differ in their length
        for which, length in (('major', 5), ('minor', 3)):
            ax.tick_params(axis='both',
                           which=which,
                           colors='black',
                           labelcolor='black',
                           direction='in',
                           width=0.8,
                           length=length,
                           labelsize=12,
                           top=True,
                           bottom=True,
                           left=True,
                           right=True,
                           labelbottom=is_bottom)

        ax.set_xlabel('Step number' if is_bottom else '', fontsize=10)
        ax.set_ylabel(labels[i], fontsize=10)

        if offset:
            ax.get_xaxis().set_label_coords(0.5, offset[0])
            ax.get_yaxis().set_label_coords(offset[1], 0.5)
        else:
            ax.get_xaxis().set_label_coords(0.5, -0.22)
            ax.get_yaxis().set_label_coords(-0.09, 0.5)

        if nsteps:
            ax.set_xlim(0, nsteps)

        # Draw the step history of each walker for the current parameter
        for walker in samples[:, :, i]:
            ax.plot(walker, ls='-', lw=0.5, color='black', alpha=0.5)

    # os.path.join keeps an absolute output path intact
    plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.close()

    sys.stdout.write(' [DONE]\n')
    sys.stdout.flush()
예제 #6
0
def plot_posterior(tag,
                   burnin,
                   output,
                   title=None,
                   offset=None,
                   title_fmt='.2f',
                   limits=None):
    """
    Function to plot the posterior distributions.

    The samples are drawn as a corner plot. The parameter values that are
    read from the ``chisquare`` attribute of the samples box are marked in
    red in each panel (presumably the best-fit values, to be confirmed
    with the database module).

    Parameters
    ----------
    tag : str
        Database tag with the MCMC samples.
    burnin : int
        Number of burnin steps to exclude.
    output : str
        Output filename. A relative path is resolved against the current
        working directory, an absolute path is used as given.
    title : str
        Plot title.
    offset : tuple(float, float)
        Offset of the x- and y-axis label. Default values are used if set
        to None.
    title_fmt : str
        Format of the median and error values.
    limits : tuple(tuple(float, float), )
        Axis limits of all parameters. Automatically set if set to None.

    Returns
    -------
    None
    """

    sys.stdout.write('Plotting posteriors: ' + output + '...')
    sys.stdout.flush()

    species_db = database.Database()
    box = species_db.get_samples(tag)

    samples = box.samples
    par_val = box.chisquare

    labels = plot_util.update_labels(box.parameters)

    ndim = samples.shape[-1]

    # Remove the burnin steps and flatten the remaining steps of all walkers
    samples = samples[:, int(burnin):, :].reshape((-1, ndim))

    fig = corner.corner(samples,
                        labels=labels,
                        quantiles=[0.16, 0.5, 0.84],
                        label_kwargs={'fontsize': 13},
                        show_titles=True,
                        title_kwargs={'fontsize': 12},
                        title_fmt=title_fmt)

    axes = np.array(fig.axes).reshape((ndim, ndim))

    # Position of the axis labels, identical for all panels
    if offset:
        xlabel_coords = (0.5, offset[0])
        ylabel_coords = (offset[1], 0.5)
    else:
        xlabel_coords = (0.5, -0.26)
        ylabel_coords = (-0.27, 0.5)

    for i in range(ndim):
        # Only the lower triangle (including the diagonal) contains panels
        for j in range(i + 1):
            ax = axes[i, j]

            # Major and minor ticks only differ in their length
            for which, length in (('major', 5), ('minor', 3)):
                ax.tick_params(axis='both',
                               which=which,
                               colors='black',
                               labelcolor='black',
                               direction='in',
                               width=0.8,
                               length=length,
                               labelsize=12,
                               top=True,
                               bottom=True,
                               left=True,
                               right=True)

            if limits is not None:
                ax.set_xlim(limits[j])

            ax.axvline(par_val[j], color='tomato')

            # Off-diagonal panels (2D projections) also get a horizontal
            # line, a marker, and limits on the y-axis
            if i > j:
                ax.axhline(par_val[i], color='tomato')
                ax.plot(par_val[j], par_val[i], 's', color='tomato')

                if limits is not None:
                    ax.set_ylim(limits[i])

            # Apply the label position to every panel. Previously this was
            # done outside the inner loop such that only the bottom-right
            # panel was adjusted.
            ax.get_xaxis().set_label_coords(*xlabel_coords)
            ax.get_yaxis().set_label_coords(*ylabel_coords)

    if title:
        fig.suptitle(title, y=1.02, fontsize=16)

    # os.path.join keeps an absolute output path intact
    plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.close()

    sys.stdout.write(' [DONE]\n')
    sys.stdout.flush()