コード例 #1
0
    def test_read_sed_data(self):
        """
        Read in the sed grid from a cached file and test that selected values
        are as expected.

        Two modes of ``read_sed_data`` are checked:

        * ``return_params=True`` must return a list naming the available
          parameters, including every requested one plus "seds" and "lamb".
        * ``param_list=...`` must return the requested parameters; entry 10 of
          each one is compared against a known value.
        """
        requested_params = ["Av", "Rv", "f_A", "M_ini", "logA", "Z", "distance"]

        # check that when return_params=True, then just a list of parameters is returned
        sparams = read_sed_data(self.seds_trim_fname_cache, return_params=True)
        assert isinstance(sparams, list), "Returned params are not a list"
        for cname in requested_params + ["seds", "lamb"]:
            assert cname in sparams, f"{cname} not in sed parameter list"

        # check that otherwise, the requested sed data is returned
        sdata = read_sed_data(self.seds_trim_fname_cache, param_list=requested_params)
        expected_values = {
            "Av": 0.0,
            "Rv": 2.0,
            "f_A": 1.0,
            "M_ini": 4.0073261261,
            "logA": 6.0,
            "Z": 0.008,
            "distance": 783429.642766212,
        }
        for cname in requested_params:
            assert cname in sdata.keys(), f"requested parameter {cname} not in sed data"
            np.testing.assert_allclose(
                sdata[cname][10],
                expected_values[cname],
                err_msg=f"expected value of {cname} is not found",
            )
コード例 #2
0
    def test_get_lnp_grid_vals(self):
        """
        Read in the lnp and sed grid data from cached files and test that
        selected values are as expected.

        ``get_lnp_grid_vals`` returns the requested SED-grid parameters at the
        grid points of the sparse likelihoods stored in the lnp file.  For
        each parameter, rows 0-4 of column 10 are compared against known
        values.
        """
        ldata = read_lnp_data(self.lnp_fname_cache)

        requested_params = [
            "Av", "Rv", "f_A", "M_ini", "logA", "Z", "distance"
        ]
        sdata = read_sed_data(self.seds_trim_fname_cache,
                              param_list=requested_params)

        lgvals_data = get_lnp_grid_vals(sdata, ldata)

        # check that every requested parameter is present with the expected values
        expected_values = {
            "Av": [0.0, 0.0, 0.0, 0.0, 0.0],
            "Rv": [2.0, 2.0, 2.0, 2.0, 2.0],
            "f_A": [1.0, 1.0, 1.0, 1.0, 1.0],
            "M_ini":
            [3.89416909, 3.92726111, 3.95603228, 2.04966068, 2.04999995],
            "logA": [6.0, 6.0, 6.0, 9.0, 9.0],
            "Z": [0.03, 0.03, 0.03, 0.004, 0.004],
            "distance": [
                783429.64276621,
                783429.64276621,
                783429.64276621,
                783429.64276621,
                783429.64276621,
            ],
        }
        for cname in requested_params:
            assert (cname in lgvals_data.keys()
                    ), f"requested parameter {cname} not in lnp grid values"
            np.testing.assert_allclose(
                lgvals_data[cname][0:5, 10],
                expected_values[cname],
                err_msg=f"expected value of {cname} is not found",
            )
コード例 #3
0
ファイル: megabeast.py プロジェクト: lea-hagen/megabeast
def megabeast(megabeast_input_file, verbose=True):
    """
    Run the MegaBEAST on each of the spatially-reordered BEAST outputs.

    The fit runs only on pixels of the nstars image that have at least
    ``min_for_fit`` stars.  For each such pixel, ``-lnprob`` is minimized
    (Nelder-Mead) over that pixel's sparse BEAST likelihoods.  Best-fit values
    are written with the nstars image header, one FITS image per entry of
    ``fit_param_names``, to
    ``<projectname>_megabeast/<projectname>_<name>_bestfit.fits``.
    Pixels that are not fit stay NaN.

    Parameters
    ----------
    megabeast_input_file : string
        Name of the file that contains settings, filenames, etc

    verbose : boolean (default=True)
        print extra info
    """
    # read in the settings from the file
    mb_settings = read_megabeast_input(megabeast_input_file)

    # setup the megabeast model including defining the priors
    #   - dust distribution model
    #   - stellar populations model (later)

    # use nstars image to setup for each pixel
    nstars_image, nstars_header = fits.getdata(mb_settings["nstars_filename"],
                                               header=True)
    n_x, n_y = nstars_image.shape

    # read in the beast data that is needed by all the pixels
    beast_data = {}
    # - SED data
    beast_data.update(
        read_beast_data.read_sed_data(
            mb_settings["beast_seds_filename"],
            param_list=["Av"]  # , "Rv", "f_A"]
        ))
    # - max completeness (collapse the completeness array along axis 1)
    beast_data.update(
        read_beast_data.read_noise_data(
            mb_settings["beast_noise_filename"],
            param_list=["completeness"],
        ))
    beast_data["completeness"] = np.max(beast_data["completeness"], axis=1)

    # setup for output
    pixel_fit_status = np.full((n_x, n_y), False, dtype=bool)
    n_fit_params = len(mb_settings["fit_param_names"])
    best_fit_images = np.full((n_x, n_y, n_fit_params), np.nan, dtype=float)

    def chi2(*args):
        # scipy.optimize.minimize calls fun(x, *args), so this receives the
        # parameter vector followed by the extra ``args``; forward them all
        # to lnprob and negate so that minimizing chi2 maximizes lnprob.
        return -1.0 * lnprob(*args)

    # loop over the pixels with enough stars in the nstars image
    for i in trange(n_x, desc="x pixels"):
        for j in trange(n_y, desc="y pixels", leave=False):
            if verbose:
                print("working on (%i,%i)" % (i, j))
            if nstars_image[i, j] >= mb_settings["min_for_fit"]:
                pixel_fit_status[i, j] = True
                # get the saved sparse likelihoods
                lnp_filename = mb_settings[
                    "lnp_file_prefix"] + "_{0}_{1}_lnp.hd5".format(j, i)
                lnp_data = read_beast_data.read_lnp_data(
                    lnp_filename,
                    nstars=nstars_image[i, j],
                    shift_lnp=True,
                )

                # get the completeness and BEAST model parameters for the
                #   same grid points as the sparse likelihoods
                lnp_grid_vals = read_beast_data.get_lnp_grid_vals(
                    beast_data, lnp_data)

                # initialize the ensemble model with the parameters used
                # for the saved BEAST model run results
                #   currently only dust parameters allowed
                #   for testing -> only Av
                avs = lnp_grid_vals["Av"]
                rvs = [3.1]  # beast_data['Rv']
                fAs = [1.0]  # beast_data['f_A']
                beast_dust_priors = PriorWeightsDust(
                    avs,
                    mb_settings["av_prior_model"],
                    rvs,
                    mb_settings["rv_prior_model"],
                    fAs,
                    mb_settings["fA_prior_model"],
                )

                # standard minimization to find initial values
                # NOTE(review): the 5-element starting guess must have the
                # same length as fit_param_names, or the assignment into
                # best_fit_images below fails — confirm.
                result = op.minimize(
                    chi2,
                    [0.25, 2.0, 0.5, 0.5, 1],
                    args=(beast_dust_priors, lnp_data, lnp_grid_vals),
                    method="Nelder-Mead",
                )
                best_fit_images[i, j, :] = result["x"]

                # TODO: run through MCMC to fully sample likelihood
                #    (include option not to run MCMC)

    # output results
    #    - best fit
    #    - megabeast parameter 1D pPDFs (future)
    #    - MCMC chain (future)

    # make sure the output directory exists (no error if it already does)
    os.makedirs("./" + mb_settings["projectname"] + "_megabeast/",
                exist_ok=True)

    # write one best-fit map per fit parameter
    for k, cname in enumerate(mb_settings["fit_param_names"]):
        hdu = fits.PrimaryHDU(best_fit_images[:, :, k], header=nstars_header)
        hdu.writeto(
            "%s_megabeast/%s_%s_bestfit.fits" %
            (mb_settings["projectname"], mb_settings["projectname"], cname),
            overwrite=True,
        )
コード例 #4
0
def test_remove_filters():
    """
    Test for remove_filters.py

    The SED grid for this test has two entries for F475W: HST_ACS_WFC_F475W and
    HST_WFC3_F475W (the actual F475W observations are with ACS). This tests four
    combinations of running remove_filters:

    1. Use the catalog to choose which filters in the SED grid to keep
        A. without using beast_filt keyword: the code doesn't know what F475W in
           the catalog means, so it won't delete either F475W entry from the
           grid
        B. using beast_filt='HST_ACS_WFC_F475W': the code knows that when it
           sees F475W in the catalog, it means ACS, so it should remove
           HST_WFC3_F475W from the grid

    2. Use the rm_filters='F475W' keyword to choose filter removal
        A. without using beast_filt keyword: any time the code sees F475W in the
           grid, it will delete it
        B. using beast_filt='HST_WFC3_F475W': the code knows that F475W in
           rm_filters means WFC3, so it will only delete the F475W WFC3 entry in
           the grid

    The temporary output grid is removed after every case, even when the case
    fails, so a failure cannot leave a stale file behind for the next run.
    """

    # download the needed files
    obs_fname = download_rename("phat_small/b15_4band_det_27_A.fits")
    seds_fname = download_rename(
        "phat_small/beast_example_phat_seds_extrafilter.hd5")

    # name to use for the output grid
    temp_physgrid_file = "temp_newgrid.hd5"

    # filters that survive in every case; only the F475W entries vary
    always_kept = [
        "HST_WFC3_F275W",
        "HST_WFC3_F336W",
        "HST_ACS_WFC_F814W",
        "HST_WFC3_F110W",
        "HST_WFC3_F160W",
    ]

    # (case label, extra keywords for remove_filters_from_files,
    #  F475W grid entries expected to be retained)
    cases = [
        ("1A", {}, ["HST_ACS_WFC_F475W", "HST_WFC3_F475W"]),
        ("1B", {"beast_filt": ["HST_ACS_WFC_F475W"]}, ["HST_ACS_WFC_F475W"]),
        ("2A", {"rm_filters": ["F475W"]}, []),
        (
            "2B",
            {"rm_filters": ["F475W"], "beast_filt": ["HST_WFC3_F475W"]},
            ["HST_ACS_WFC_F475W"],
        ),
    ]

    for label, extra_kwargs, kept_f475w in cases:
        try:
            # run filter removal
            remove_filters.remove_filters_from_files(
                obs_fname,
                physgrid=seds_fname,
                physgrid_outfile=temp_physgrid_file,
                **extra_kwargs,
            )

            # check that the proper filters are retained
            expected_filters = set(always_kept + kept_f475w)
            temp = read_beast_data.read_sed_data(temp_physgrid_file,
                                                 param_list=["filters"])
            assert set(temp["filters"]) == expected_filters, \
                f"remove_filters case {label} doesn't match"
        finally:
            # remove temp file (also when the case above failed)
            if os.path.exists(temp_physgrid_file):
                os.remove(temp_physgrid_file)
コード例 #5
0
def megabeast_image(megabeast_input_file, verbose=True):
    """
    Fit the MegaBEAST ensemble model pixel by pixel over an image.

    The BEAST results are expected as spatially-reordered outputs, with one
    lnp (sparse likelihood) file per pixel of the nstars image.  Pixels with
    at least ``min_for_fit`` stars are fit with ``fit_ensemble``.  The best
    fits are saved as one FITS image per fit parameter in
    ``<projectname>_megabeast/``, using the nstars image header.  Pixels that
    are not fit stay NaN.

    Parameters
    ----------
    megabeast_input_file : string
        Name of the file that contains settings, filenames, etc

    verbose : boolean (default=True)
        print extra info (not currently used by this function)
    """
    settings = read_input(megabeast_input_file)

    # the nstars image defines the pixel grid to fit
    star_counts, counts_header = fits.getdata(settings["nstars_filename"],
                                              header=True)
    n_x, n_y = star_counts.shape

    # model quantities shared by every pixel: SED-grid A_V and completeness
    beast_data = {}
    beast_data.update(
        read_sed_data(settings["beast_seds_filename"], param_list=["Av"]))
    beast_data.update(
        read_noise_data(
            settings["beast_noise_filename"],
            param_list=["completeness"],
        ))
    # the toothpick noise model gives one completeness value per band; the
    # ensemble fit wants a single value per model, so keep the maximum
    # (this choice may not be correct)
    beast_data["completeness"] = np.max(beast_data["completeness"], axis=1)

    # prior models that were used in the BEAST run
    beast_pmodel = {
        "AV": settings["av_prior_model"],
        "RV": settings["rv_prior_model"],
        "fA": settings["fA_prior_model"],
    }

    # output containers: fit flags, and NaN-filled best-fit cubes
    fitted = np.full((n_x, n_y), False, dtype=bool)
    param_names = settings["fit_param_names"]
    best_fit_images = np.full((n_x, n_y, len(param_names)), np.nan,
                              dtype=float)

    for i in trange(n_x, desc="x pixels"):
        for j in trange(n_y, desc="y pixels", leave=False):
            n_stars = star_counts[i, j]
            # written as "not >=" so NaN pixels are skipped, too
            if not (n_stars >= settings["min_for_fit"]):
                continue

            fitted[i, j] = True
            lnp_filename = f"{settings['lnp_file_prefix']}_{j}_{i}_lnp.hd5"
            best_fit_images[i, j, :] = fit_ensemble(
                beast_data,
                lnp_filename,
                beast_pmodel,
                nstars_expected=n_stars,
            )

    # output results (* = future)
    #    - best fit
    #    - *megabeast parameter 1D pPDFs
    #    - *MCMC chain
    project = settings["projectname"]
    out_dir = "./%s_megabeast/" % (project)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    for k, cname in enumerate(param_names):
        fits.PrimaryHDU(best_fit_images[:, :, k], header=counts_header).writeto(
            "%s_megabeast/%s_%s_bestfit.fits" % (project, project, cname),
            overwrite=True,
        )
コード例 #6
0
ファイル: plot_input_data.py プロジェクト: petiay/megabeast
def plot_input_data(megabeast_input_file, chi2_plot=None, log_scale=False,
                    min_nstars=20):
    """
    Make diagnostic plots of the BEAST results that feed the MegaBEAST.

    The pages go into a multi-page PDF,
    ``<projectname>_megabeast/plot_input_data.pdf``
    (``plot_input_data_log.pdf`` when ``log_scale`` is True):

    1. completeness vs. A_V for a random subset of the sparse-likelihood grid
       points (every third pixel in each direction), with each star's
       best-fit point overplotted in magenta
    2. a grid of per-pixel histograms of each star's best-fit A_V
    3. one additional histogram page for each value in ``chi2_plot``, using
       only stars whose best fit has chi^2 below that value

    Parameters
    ----------
    megabeast_input_file : string
        Name of the file that contains settings, filenames, etc

    chi2_plot : list of floats (default=None, meaning no chi2-cut pages)
        Make A_V histogram(s) with chi2 less than each of the values in this list

    log_scale : boolean (default=False)
        If True, make the histogram x-axis a log scale (to visualize log-normal
        A_V distribution)

    min_nstars : int (default=20)
        Only pixels with more than this many stars in the nstars image are
        read and plotted
    """
    if chi2_plot is None:
        chi2_plot = []

    # read in the settings from the file
    mb_settings = read_input(megabeast_input_file)

    # get the project name
    projectname = mb_settings["projectname"]

    # read in the beast data that is needed by all the pixels
    beast_data = {}
    # - SED data
    beast_data.update(
        read_beast_data.read_sed_data(
            mb_settings["beast_seds_filename"],
            param_list=["Av"]  # , "Rv", "f_A"]
        ))
    # - max completeness
    beast_data.update(
        read_beast_data.read_noise_data(
            mb_settings["beast_noise_filename"],
            param_list=["completeness"],
        ))
    beast_data["completeness"] = np.max(beast_data["completeness"], axis=1)

    # read in the nstars image; its shape sets the layout of the plot grids
    nstars_image, nstars_header = fits.getdata(mb_settings["nstars_filename"],
                                               header=True)
    y_dimen, x_dimen = nstars_image.shape

    def to_axis(values):
        # map A_V values onto the histogram x-axis (log10 A_V if requested)
        arr = np.asarray(values, dtype=float)
        return np.log10(arr) if log_scale else arr

    if log_scale:
        pdf_name = "{0}_megabeast/plot_input_data_log.pdf".format(projectname)
    else:
        pdf_name = "{0}_megabeast/plot_input_data.pdf".format(projectname)

    # the context manager closes the PDF even if a plotting step fails
    with PdfPages(pdf_name) as pp:

        # -----------------
        # Completeness vs A_V
        # -----------------

        print("")
        print("Making completeness/Av plot")
        print("")

        plt.figure(figsize=(6, 6))
        plt.subplot(1, 1, 1)
        best_av, best_av_chi2 = _plot_completeness_vs_av(
            mb_settings, beast_data, nstars_image, min_nstars)
        ax = plt.gca()
        ax.set_xlabel(r"$A_V$")
        ax.set_ylabel("Completeness")
        pp.savefig()
        plt.close()

        # -----------------
        # histograms of AVs
        # -----------------

        print("")
        print("Making Av Histograms")
        print("")

        flat_av = [av for row in best_av for pixel in row for av in pixel]
        n_bins, hist_range = _av_hist_binning(to_axis(flat_av), log_scale)

        plt.figure(figsize=(x_dimen * 2, y_dimen * 2))
        _plot_av_histogram_grid(
            [[to_axis(pixel) for pixel in row] for row in best_av],
            nstars_image, min_nstars, n_bins, hist_range)
        plt.suptitle(r"Best-fit $A_V$ for each pixel", fontsize=40)
        pp.savefig()
        plt.close()

        # -----------------
        # histograms of AVs with a chi2 cut
        # -----------------

        if len(chi2_plot) > 0:
            print("")
            print("Making Av Histograms with chi^2 cut")
            print("")

        for chi2_cut in chi2_plot:
            # keep only the stars whose best fit beats the chi2 cut
            cut_av = [
                [to_axis(av)[np.asarray(chi2, dtype=float) < chi2_cut]
                 for av, chi2 in zip(av_row, chi2_row)]
                for av_row, chi2_row in zip(best_av, best_av_chi2)
            ]

            plt.figure(figsize=(x_dimen * 2, y_dimen * 2))
            _plot_av_histogram_grid(cut_av, nstars_image, min_nstars, n_bins,
                                    hist_range)
            plt.suptitle(
                r"Best-fit $A_V$ for each pixel, but only using sources with $\chi^2 < $"
                + str(chi2_cut),
                fontsize=40,
            )
            pp.savefig()
            plt.close()


def _plot_completeness_vs_av(mb_settings, beast_data, nstars_image,
                             min_nstars):
    """
    Plot completeness vs. A_V on the current axes and collect each star's best fit.

    Each pixel with more than ``min_nstars`` stars has its sparse likelihoods
    read.  On every third pixel in each direction, up to 20 randomly chosen
    grid points per star are drawn in black, with A_V jittered for
    visibility.  The star's best-fit grid point is drawn in magenta.

    Returns
    -------
    best_av, best_av_chi2 : nested lists indexed ``[i][j]``
        For every pixel, the best-fit A_V of each star and the matching
        chi^2 (``-2 * max lnp``).  Pixels that were skipped hold empty lists.
    """
    y_dimen, x_dimen = nstars_image.shape
    best_av = [[[] for _ in range(x_dimen)] for _ in range(y_dimen)]
    best_av_chi2 = [[[] for _ in range(x_dimen)] for _ in range(y_dimen)]

    for i in tqdm(range(y_dimen), desc="y pixels"):
        for j in tqdm(range(x_dimen), desc="x pixels"):
            if not (nstars_image[i, j] > min_nstars):
                continue

            # get info about the fits
            lnp_filename = mb_settings[
                "lnp_file_prefix"] + "_{0}_{1}_lnp.hd5".format(j, i)
            lnp_data = read_beast_data.read_lnp_data(
                lnp_filename,
                nstars=nstars_image[i, j],
                shift_lnp=True,
            )
            lnp_vals = lnp_data["vals"]

            # get the completeness and BEAST model parameters for the
            #   same grid points as the sparse likelihoods
            lnp_grid_vals = read_beast_data.get_lnp_grid_vals(
                beast_data, lnp_data)
            grid_av = lnp_grid_vals["Av"]
            grid_comp = lnp_grid_vals["completeness"]

            # only every third pixel in each direction is drawn
            draw_pixel = (i % 3 == 0) and (j % 3 == 0)

            # int(): the nstars image may be stored as floats, which range()
            # does not accept
            for n in range(int(nstars_image[i, j])):

                # plot a random subset of the AVs and completenesses
                # (never ask for more points than the star has)
                if draw_pixel:
                    n_grid = grid_av[:, n].size
                    plot_these = np.random.choice(n_grid,
                                                  size=min(20, n_grid),
                                                  replace=False)
                    plt.plot(
                        grid_av[plot_these, n] +
                        np.random.normal(scale=0.02, size=plot_these.size),
                        grid_comp[plot_these, n],
                        marker=".",
                        c="black",
                        ms=3,
                        mew=0,
                        linestyle="None",
                        alpha=0.05,
                    )

                # best fit = first grid point with the maximum lnp
                max_ind = np.argmax(lnp_vals[:, n])
                best_av[i][j].append(grid_av[max_ind, n])
                best_av_chi2[i][j].append(-2 * lnp_vals[max_ind, n])

                # also overplot the values for the best fit
                if draw_pixel:
                    plt.plot(
                        grid_av[max_ind, n] + np.random.normal(scale=0.01),
                        grid_comp[max_ind, n],
                        marker=".",
                        c="magenta",
                        ms=2,
                        mew=0,
                        linestyle="None",
                        alpha=0.3,
                        zorder=9999,
                    )

    return best_av, best_av_chi2


def _av_hist_binning(axis_values, log_scale):
    """
    Return ``(n_bins, (lo, hi))`` for the per-pixel A_V histograms.

    ``axis_values`` are already on the plot axis (log10 A_V if ``log_scale``).
    Only finite values are used, so log10(0) = -inf does not break the
    binning.  ``gap`` is the smallest spacing between unique values on a
    linear axis, and the span divided by the number of unique values on a
    log axis.  The range extends half a gap past the extreme values, and the
    bin count makes every bin exactly one gap wide, so on a linear axis each
    bin is centered on one grid value.
    """
    values = np.asarray(axis_values, dtype=float)
    uniq = np.unique(values[np.isfinite(values)])
    if uniq.size == 0:
        # nothing will be histogrammed; any valid binning will do
        return 1, (0.0, 1.0)
    if uniq.size == 1:
        # a single value has no spacing to measure; use a unit-wide bin
        gap = 1.0
    elif log_scale:
        gap = (uniq[-1] - uniq[0]) / len(uniq)
    else:
        gap = np.min(np.diff(uniq))
    # round() absorbs floating-point noise in span / gap
    n_bins = int(np.round((uniq[-1] - uniq[0]) / gap)) + 1
    return n_bins, (uniq[0] - gap / 2, uniq[-1] + gap / 2)


def _plot_av_histogram_grid(values_grid, nstars_image, min_nstars, n_bins,
                            hist_range):
    """
    Draw one subplot per pixel with more than ``min_nstars`` stars.

    The subplot gets a histogram of ``values_grid[i][j]`` when that pixel has
    any values.  Image row 0 is placed at the bottom of the figure.
    """
    y_dimen, x_dimen = nstars_image.shape
    for i in tqdm(range(y_dimen), desc="y pixels"):
        for j in tqdm(range(x_dimen), desc="x pixels"):
            if not (nstars_image[i, j] > min_nstars):
                continue

            # set up the subplot
            plt.subplot(y_dimen, x_dimen,
                        (y_dimen - i - 1) * (x_dimen) + j + 1)

            values = values_grid[i][j]
            if len(values) != 0:
                plt.hist(
                    values,
                    bins=n_bins,
                    range=hist_range,
                    facecolor="xkcd:azure",
                    linewidth=0.25,
                    edgecolor="xkcd:azure",
                )
コード例 #7
0
def simulate_av_plots(
    megabeast_input_file, log_scale=False, input_lognormal=None, input_lognormal2=None
):
    """
    Plot distributions of simulated AVs, and overplot the best fit lognormals

    Each pixel with more than 20 stars in the nstars image gets one subplot.
    The subplot shows:

    * the completeness-weighted stack of the stars' BEAST probabilities
    * optionally, the input lognormal(s)
    * the best-fit two-lognormal A_V model read from the MegaBEAST best-fit
      images

    The figure is saved to
    ``./<projectname>_megabeast/<projectname>_bestfit_plot.pdf``.

    Parameters
    ----------
    megabeast_input_file : string
        Name of the file that contains settings, filenames, etc

    log_scale : boolean (default=False)
        If True, make the histogram x-axis a log scale (to visualize log-normal
        A_V distribution)

    input_lognormal, input_lognormal2 : dict (default=None)
        Set these to the original values used to create the fake data, and they
        will also be plotted.  Each needs the keys "max_pos", "sigma" and "N".
        input_lognormal2 is only used when input_lognormal is also given.
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    # read in the settings from the file
    mb_settings = read_input(megabeast_input_file)

    # get the project name
    projectname = mb_settings["projectname"]

    # read in the beast data that is needed by all the pixels
    # *** this likely needs updating - probably will fail - see megabeast.py
    # NOTE(review): elsewhere read_sed_data is called as
    # read_sed_data(seds_file, param_list=[...]), with completeness read via
    # read_noise_data; confirm this call signature.
    beast_data = read_sed_data(
        mb_settings["beast_seds_filename"],
        mb_settings["beast_noise_filename"],
        beast_params=["completeness", "Av"],
    )  # ,'Rv','f_A'])
    av_grid = np.unique(beast_data["Av"])

    # also make a more finely sampled A_V grid
    if log_scale:
        av_grid_big = np.geomspace(np.min(av_grid), np.max(av_grid), 500)
    else:
        av_grid_big = np.linspace(np.min(av_grid), np.max(av_grid), 500)

    # x-axis positions used in every subplot (log10 A_V for a log plot)
    x_grid = np.log10(av_grid) if log_scale else av_grid
    x_grid_big = np.log10(av_grid_big) if log_scale else av_grid_big

    # read in the nstars image
    nstars_image, nstars_header = fits.getdata(
        mb_settings["nstars_filename"], header=True
    )
    # dimensions of images/plotting
    y_dimen = nstars_image.shape[0]
    x_dimen = nstars_image.shape[1]

    # read in the best fits
    best_fits = {}
    for label in mb_settings["fit_param_names"]:
        with fits.open(
            "./"
            + projectname
            + "_megabeast/"
            + projectname
            + "_"
            + label
            + "_bestfit.fits"
        ) as hdu:
            best_fits[label] = hdu[0].data

    # set colors for plots
    # (pyplot.get_cmap still exists in matplotlib releases where
    #  matplotlib.cm.get_cmap has been removed)
    cmap = plt.get_cmap("inferno")
    color_data = cmap(0.0)
    color_fit = cmap(0.5)
    if input_lognormal is not None:
        color_input = cmap(0.85)

    # -----------------
    # plotting
    # -----------------

    # set up figure
    fig = plt.figure(figsize=(x_dimen * 2, y_dimen * 2))

    for i in tqdm(range(y_dimen), desc="y pixels"):
        for j in tqdm(range(x_dimen), desc="x pixels"):

            if not (nstars_image[i, j] > 20):
                continue

            # -------- data

            # read in the original lnp data
            lnp_filename = mb_settings["lnp_file_prefix"] + "_%i_%i_lnp.hd5" % (
                j,
                i,
            )
            lnp_data = read_lnp_data(lnp_filename, nstars_image[i, j])
            lnp_vals = np.array(lnp_data["vals"])

            # completeness for each of the values
            lnp_comp = beast_data["completeness"][lnp_data["indxs"]]

            # normalize each star's (column's) probabilities to sum to 1,
            # in log space with the log-sum-exp trick so that very negative
            # lnp values do not underflow exp() and give log(0) or NaN
            col_max = np.max(lnp_vals, axis=0)
            lnp_vals = lnp_vals - (
                col_max + np.log(np.sum(np.exp(lnp_vals - col_max), axis=0))
            )

            # stack up some representation of what's being maximized in ensemble_model.py
            # NOTE(review): prob_stack has one entry per row of lnp_vals,
            # while av_grid has one per unique A_V; the trapezoid()/plot()
            # calls below need equal lengths — confirm, or accumulate the
            # probabilities onto av_grid first.
            prob_stack = np.sum(lnp_comp * np.exp(lnp_vals), axis=1)

            # normalize it (since it's not clear what the numbers mean anyway)
            prob_stack = prob_stack / trapezoid(prob_stack, av_grid)

            # set up the subplot
            plt.subplot(y_dimen, x_dimen, (y_dimen - i - 1) * (x_dimen) + j + 1)

            plt.plot(
                x_grid,
                prob_stack,
                marker=".",
                ms=0,
                mew=0,
                linestyle="-",
                color=color_data,
                linewidth=4,
            )

            ax = plt.gca()

            # -------- input lognormal(s)

            if input_lognormal is not None:

                # create lognormal
                lognorm = _lognorm(
                    av_grid_big,
                    input_lognormal["max_pos"],
                    input_lognormal["sigma"],
                    input_lognormal["N"],
                )

                # if there's a second lognormal
                if input_lognormal2 is not None:
                    lognorm += _lognorm(
                        av_grid_big,
                        input_lognormal2["max_pos"],
                        input_lognormal2["sigma"],
                        input_lognormal2["N"],
                    )

                # normalize it
                lognorm = lognorm / trapezoid(lognorm, av_grid_big)

                plt.plot(
                    x_grid_big,
                    lognorm,
                    marker=".",
                    ms=0,
                    mew=0,
                    linestyle="-",
                    color=color_input,
                    linewidth=2,
                    alpha=0.85,
                )

            # -------- best fit

            # split the pixel's star count between the two components
            n12_ratio = best_fits["N12_ratio"][i, j]
            lognorm = _two_lognorm(
                av_grid_big,
                best_fits["Av1"][i, j],
                best_fits["Av2"][i, j],
                sigma1=best_fits["sigma1"][i, j],
                sigma2=best_fits["sigma2"][i, j],
                N1=nstars_image[i, j] * (1 - 1 / (n12_ratio + 1)),
                N2=nstars_image[i, j] / (n12_ratio + 1),
            )

            # normalize it
            lognorm = lognorm / trapezoid(lognorm, av_grid_big)

            # plot it without letting it change the y-range already set by
            # the curves drawn above
            yrange_before = ax.get_ylim()
            plt.plot(
                x_grid_big,
                lognorm,
                marker=".",
                ms=0,
                mew=0,
                dashes=[3, 1.5],
                color=color_fit,
                linewidth=2,
            )
            ax.set_ylim(yrange_before)

    # invisible full-figure axes that carry the shared axis labels
    fig.add_subplot(111, frameon=False)
    # tick_params expects booleans here: the string "off" is truthy
    plt.tick_params(labelcolor="none", top=False, bottom=False, left=False, right=False)
    plt.grid(False)
    if not log_scale:
        plt.xlabel(r"$A_V$", size=15)
    else:
        plt.xlabel(r"Log $A_V$", size=15)
    plt.ylabel("PDF", size=15)
    plt.tight_layout()

    # save figure
    plt.savefig("./" + projectname + "_megabeast/" + projectname + "_bestfit_plot.pdf")
    plt.close()