def test_flat_age_prior_weights():
    """
    Flat age prior: every log(age) point should receive a weight of one.
    """
    model = {"name": "flat"}
    ages = np.array([6.0, 7.0, 8.0, 9.0, 10.0])
    np.testing.assert_allclose(
        compute_age_prior_weights(ages, model),
        np.ones(ages.size),
        err_msg=("Flat age prior error"),
    )
# Example #2
# 0
def test_flat_log_age_prior_weights():
    """
    Flat-in-log(age) prior evaluated with ``"sfr": 1.0`` in the model dict;
    the weights should drop by a factor of ten per dex in age.
    """
    model = {"name": "flat_log", "sfr": 1.0}
    ages = np.linspace(6.0, 10.0, num=5)
    expected = [2.2222e03, 2.2222e02, 2.2222e01, 2.2222e00, 2.2222e-01]
    np.testing.assert_allclose(
        compute_age_prior_weights(ages, model),
        expected,
        err_msg=("Flat log, log age prior error"),
    )
def test_bins_histo_age_prior_weights():
    """
    Histogram ("bins_histo") age prior evaluated at the interior grid points
    log(age) = 7, 8 and 9 of a five-point definition.
    """
    bin_logages = [6.0, 7.0, 8.0, 9.0, 10.0]
    bin_values = [1.0, 2.0, 1.0, 5.0, 3.0]
    model = dict(name="bins_histo", logages=bin_logages, values=bin_values)

    weights = compute_age_prior_weights(np.array([7.0, 8.0, 9.0]), model)

    np.testing.assert_allclose(
        weights, [0.75, 0.375, 1.875], err_msg=("Bin histogram log age prior error")
    )
def test_flat_log_age_prior_weights():
    """
    Flat-in-log(age) prior using only the model name (no extra parameters);
    the weights should drop by a factor of ten per dex in age.
    """
    weights = compute_age_prior_weights(
        np.array([6.0, 7.0, 8.0, 9.0, 10.0]), {"name": "flat_log"}
    )
    reference = (
        4.500045e00,
        4.500045e-01,
        4.500045e-02,
        4.500045e-03,
        4.500045e-04,
    )
    np.testing.assert_allclose(
        weights, reference, err_msg=("Flat log, log age prior error")
    )
def test_exp_age_prior_weights():
    """
    Exponential age prior with ``tau = 0.1``: the weights fall off steeply
    toward the oldest ages.
    """
    ages = np.array([6.0, 7.0, 8.0, 9.0, 10.0])
    model = dict(name="exp", tau=0.1)
    reference = np.array(
        [
            2.18765367e00,
            1.99936491e00,
            8.12881110e-01,
            1.00317499e-04,
            8.22002849e-44,
        ]
    )
    np.testing.assert_allclose(
        compute_age_prior_weights(ages, model),
        reference,
        err_msg=("Exponential log age prior error"),
    )
def test_bins_interp_age_prior_weights():
    """
    Test for the bin interpolation ("bins_interp") age prior.

    The prior is defined at the same log(age) points it is evaluated at, so
    the weights are the input values rescaled by a common factor.
    """
    log_age = np.array([6.0, 7.0, 8.0, 9.0, 10.0])
    log_age_prior_model = {
        "name": "bins_interp",
        "logages": [6.0, 7.0, 8.0, 9.0, 10.0],
        "values": [1.0, 2.0, 1.0, 5.0, 3.0],
    }
    log_age_prior = compute_age_prior_weights(log_age, log_age_prior_model)
    expected_log_age_prior = [
        0.41666667, 0.83333333, 0.41666667, 2.08333333, 1.25
    ]
    # failure message names the interpolation prior (was copy-pasted from the
    # histogram test, which pointed at the wrong prior on failure)
    np.testing.assert_allclose(
        log_age_prior,
        expected_log_age_prior,
        err_msg=("Bin interpolation log age prior error"),
    )
# Example #7
# 0
def test_exp_age_prior_weights():
    """
    Exponential age prior with ``tau = 0.1``, compared at a relative
    tolerance of 1e-6.
    """
    reference = np.array(
        [
            9.900498e-01,
            9.048374e-01,
            3.678794e-01,
            4.539993e-05,
            3.720076e-44,
        ]
    )
    weights = compute_age_prior_weights(
        np.array([6.0, 7.0, 8.0, 9.0, 10.0]), {"name": "exp", "tau": 0.1}
    )
    np.testing.assert_allclose(
        weights,
        reference,
        rtol=1e-6,
        err_msg=("Exponential log age prior error"),
    )
# Example #8
# 0
def _add_flux_rate_mag_columns(table, prefix, fluxes, vega_flux):
    """
    Add ``<prefix>_FLUX``, ``<prefix>_RATE`` and ``<prefix>_VEGA`` columns.

    The rate is the flux divided by the band's Vega flux. The Vega magnitude
    is ``-2.5 log10(rate)`` for positive rates and 99.999 for non-positive
    rates; NaN rates are left as NaN.
    """
    rate = fluxes / vega_flux
    mag = np.array(rate)  # copy, so the rate values are not overwritten
    positive = rate > 0.0
    mag[positive] = -2.5 * np.log10(rate[positive])
    mag[rate <= 0.0] = 99.999
    table[f"{prefix}_FLUX"] = Column(fluxes)
    table[f"{prefix}_RATE"] = Column(rate)
    table[f"{prefix}_VEGA"] = Column(mag)


def gen_SimObs_from_sedgrid(
    sedgrid,
    sedgrid_noisemodel,
    nsim=100,
    compl_filter="F475W",
    complcut=None,
    magcut=None,
    ranseed=None,
    vega_fname=None,
    weight_to_use="weight",
    age_prior_model=None,
    mass_prior_model=None,
):
    """
    Generate simulated observations using the physics and observation grids.
    The priors are sampled as they give the ensemble model for the stellar
    and dust distributions (IMF, Av distribution etc.).
    The physics model gives the SEDs based on the priors.
    The observation model gives the noise, bias, and completeness all of
    which are used in simulating the observations.

    Currently written to only work for the toothpick noisemodel.

    Parameters
    ----------
    sedgrid: grid.SEDgrid instance
        model grid

    sedgrid_noisemodel: beast noisemodel instance
        noise model data

    nsim : int
        number of observations to simulate (ignored, and recomputed, when
        both age_prior_model and mass_prior_model are given)

    compl_filter : str
        Filter to use for completeness (required for toothpick model).
        Set to max to use the max value in all filters.

    complcut : float (default=None)
        completeness cut for only including model seds above the cut
        where the completeness cut ranges between 0 and 1.

    magcut : float (default=None)
        faint-end magnitude cut for only including model seds brighter
        than the given magnitude in compl_filter. Requires a specific
        compl_filter (not "max").

    ranseed : int
        used to set the seed to make the results reproducible,
        useful for testing

    vega_fname : string
        filename for the vega info, useful for testing

    weight_to_use : string (default='weight')
        Set to either 'weight' (prior+grid), 'prior_weight', 'grid_weight',
        or 'uniform' (this option is valid only when nsim is supplied) to
        choose the weighting for SED selection.

    age_prior_model : dict
        age prior model in the BEAST dictionary format

    mass_prior_model : dict
        mass prior model in the BEAST dictionary format

    Returns
    -------
    simtable : astropy Table
        table giving the simulated observed fluxes as well as the
        physics model parameters

    Raises
    ------
    ValueError
        if magcut is given together with compl_filter="max"
    NotImplementedError
        if compl_filter is neither "max" nor one of the grid filters
    """
    # magcut is defined in the completeness filter, which "max" does not pick
    if magcut is not None and compl_filter.lower() == "max":
        raise ValueError(
            "magcut requires a specific compl_filter; it cannot be used "
            "with compl_filter='max'"
        )

    n_models, n_filters = sedgrid.seds.shape
    flux = sedgrid.seds

    # get the vega fluxes for the filters
    _, vega_flux, _ = Vega(source=vega_fname).getFlux(sedgrid.filters)

    # cache the noisemodel values
    model_bias = sedgrid_noisemodel["bias"]
    model_unc = np.fabs(sedgrid_noisemodel["error"])
    model_compl = sedgrid_noisemodel["completeness"]

    # only use models that have non-zero completeness in all filters
    # zero completeness means the observation model is not defined for that filters/flux
    goodobsmod = np.sum(model_compl > 0.0, axis=1) >= n_filters

    # completeness from toothpick model so n band completeness values
    # require only 1 completeness value for each model
    if compl_filter.lower() == "max":
        # max picked to best "simulate" how the photometry detection is done
        model_compl = np.max(model_compl, axis=1)
    else:
        short_filters = [
            fname.split(sep="_")[-1].upper() for fname in sedgrid.filters
        ]
        if compl_filter.upper() not in short_filters:
            raise NotImplementedError(
                "Requested completeness filter not present: "
                + compl_filter.upper()
                + "\nPossible filters:\n"
                + "\n".join(short_filters)
            )

        filter_k = short_filters.index(compl_filter.upper())
        print("Completeness from %s" % sedgrid.filters[filter_k])
        model_compl = model_compl[:, filter_k]

    # if complcut is provided, only use models above that completeness cut
    # in addition to the non-zero completeness criterion
    if complcut is not None:
        goodobsmod &= model_compl >= complcut

    # if magcut is provided, only use models brighter than the magnitude cut
    # in addition to the non-zero completeness criterion
    if magcut is not None:
        fluxcut_compl_filter = 10 ** (-0.4 * magcut) * vega_flux[filter_k]
        goodobsmod &= flux[:, filter_k] >= fluxcut_compl_filter

    # initialize the random number generator
    rangen = default_rng(ranseed)

    model_indx = np.arange(n_models)
    if (age_prior_model is not None) and (mass_prior_model is not None):
        # the age and mass prior models determine the total number of stars

        # np.trapz was renamed np.trapezoid in NumPy 2.0 (trapz deprecated)
        trapezoid = getattr(np, "trapezoid", None) or np.trapz

        nsim = 0
        mass_range = [np.min(sedgrid["M_ini"]), np.max(sedgrid["M_ini"])]

        # compute the total mass and average mass of a star given the mass_prior_model
        nmass = 100
        masspts = np.logspace(
            np.log10(mass_range[0]), np.log10(mass_range[1]), nmass
        )
        massprior = compute_mass_prior_weights(masspts, mass_prior_model)
        totmass = trapezoid(massprior, masspts)
        avemass = trapezoid(masspts * massprior, masspts) / totmass

        # compute the mass of the remaining stars at each age and
        # simulate the stars assuming everything is complete
        gridweights = sedgrid[weight_to_use]
        gridweights = gridweights / np.sum(gridweights)

        grid_ages = np.unique(sedgrid["logA"])
        ageprior = compute_age_prior_weights(grid_ages, age_prior_model)
        bin_boundaries = compute_bin_boundaries(grid_ages)
        bin_widths = np.diff(10 ** (bin_boundaries))
        totsim_indx = np.array([], dtype=int)
        for cage, cwidth, cprior in zip(grid_ages, bin_widths, ageprior):
            gmods = sedgrid["logA"] == cage
            cur_masses = sedgrid["M_ini"][gmods]
            gmass = (masspts >= np.min(cur_masses)) & (
                masspts <= np.max(cur_masses)
            )
            totcurmass = trapezoid(massprior[gmass], masspts[gmass])

            # compute the mass remaining at each age -> this is the mass to simulate
            simmass = cprior * cwidth * totcurmass / totmass
            nsim_curage = int(round(simmass / avemass))
            if nsim_curage == 0:
                # nothing to draw; skipping also avoids normalizing all-zero
                # weights into NaN probabilities (choice() rejects those even
                # for size=0). A size-0 draw consumes no random numbers, so
                # the random stream is unchanged by the skip.
                continue

            # simulate the stars at the current age
            curweights = gridweights[gmods]
            curweights /= np.sum(curweights)
            cursim_indx = rangen.choice(
                model_indx[gmods], size=nsim_curage, p=curweights
            )

            totsim_indx = np.concatenate((totsim_indx, cursim_indx))
            nsim += nsim_curage

        totsimmass = np.sum(sedgrid["M_ini"][totsim_indx])
        print(f"number total simulated stars = {nsim}; mass = {totsimmass}")

        # keep a star if it passes its completeness draw AND its model passes
        # the observation-model / complcut / magcut selection (goodobsmod is
        # applied after the draw so the random stream is the same as before)
        compl_choice = rangen.random(nsim)
        keep = (model_compl[totsim_indx] >= compl_choice) & goodobsmod[
            totsim_indx
        ]
        sim_indx = totsim_indx[keep]
        totcompsimmass = np.sum(sedgrid["M_ini"][sim_indx])
        print(
            f"number of simulated stars w/ completeness = {len(sim_indx)}; mass = {totcompsimmass}"
        )

    else:  # total number of stars to simulate set by command line input

        if weight_to_use == "uniform":
            # sample to get the indices of the picked models
            sim_indx = rangen.choice(model_indx[goodobsmod], nsim)

        else:
            gridweights = (
                sedgrid[weight_to_use][goodobsmod] * model_compl[goodobsmod]
            )
            gridweights = gridweights / np.sum(gridweights)

            # sample to get the indexes of the picked models
            sim_indx = rangen.choice(
                model_indx[goodobsmod], size=nsim, p=gridweights
            )

        print(f"number of simulated stars = {nsim}")

    # setup the output table
    ot = Table()
    qnames = list(sedgrid.keys())
    for k, fname in enumerate(sedgrid.filters):
        # simulated data: model flux + bias, scattered by the model uncertainty
        simflux_wbias = flux[sim_indx, k] + model_bias[sim_indx, k]
        simflux = rangen.normal(loc=simflux_wbias, scale=model_unc[sim_indx, k])

        bname = fname.split(sep="_")[-1].upper()
        _add_flux_rate_mag_columns(ot, bname, simflux, vega_flux[k])

        # add in the physical model values in a form similar to
        # the output simulated (physics+obs models) values
        # useful if using the simulated data to interpolate ASTs
        #   (e.g. for MATCH)
        _add_flux_rate_mag_columns(
            ot, f"{bname}_INPUT", flux[sim_indx, k], vega_flux[k]
        )

    # model parameters
    for qname in qnames:
        ot[qname] = Column(sedgrid[qname][sim_indx])

    return ot
def _sum_to_scale(totals):
    """
    Return ``sum(totals) / totals`` element-wise, using 0 where a total is 0.

    Multiplying group ``i`` (whose values sum to ``totals[i]``) by this factor
    makes every group sum to the same grand total, i.e. it removes
    group-to-group differences in total weight.
    """
    totals = np.asarray(totals, dtype=float)
    scale = np.zeros_like(totals)
    nonzero = totals != 0
    scale[nonzero] = np.sum(totals) / totals[nonzero]
    return scale


def compute_age_mass_metallicity_weights(_tgrid,
                                         indxs,
                                         age_prior_model=None,
                                         mass_prior_model=None,
                                         met_prior_model=None,
                                         **kwargs):
    """
    Computes the age-mass-metallicity grid and prior weights
    on the BEAST model spectra grid.

    The "grid_weight", "prior_weight" and "weight" columns of ``_tgrid`` are
    multiplied in place for the rows listed in ``indxs``.

    Keywords
    --------
    _tgrid : BEAST model spectra grid
        table-like grid with "Z", "logA", "M_ini", "grid_weight",
        "prior_weight" and "weight" columns

    indxs : ndarray of int
        row indices of ``_tgrid`` to process (assumed free of duplicates)

    age_prior_model : dict, optional
        dict including prior model name and parameters;
        ``{"name": "flat"}`` when not given

    mass_prior_model : dict, optional
        dict including prior model name and parameters;
        ``{"name": "kroupa"}`` when not given

    met_prior_model : dict, optional
        dict including prior model name and parameters;
        ``{"name": "flat"}`` when not given

    **kwargs
        accepted but unused

    Returns
    -------
    Grid and prior weight columns updated in place by multiplying by the
    age-mass-metallicity weight.
    """
    # fresh dicts per call instead of shared mutable defaults
    if age_prior_model is None:
        age_prior_model = {"name": "flat"}
    if mass_prior_model is None:
        mass_prior_model = {"name": "kroupa"}
    if met_prior_model is None:
        met_prior_model = {"name": "flat"}

    # Extract the needed columns once. Indexing rows first (``_tgrid[indxs]``)
    # copies them, which was previously repeated for every age.
    sel_Z = np.asarray(_tgrid["Z"])[indxs]
    sel_logA = np.asarray(_tgrid["logA"])[indxs]
    all_M_ini = np.asarray(_tgrid["M_ini"])

    # get the unique metallicities
    uniq_Zs = np.unique(sel_Z)

    # setup the vectors to hold the total weight at each metallicity
    total_z_grid_weight = np.zeros(len(uniq_Zs))
    total_z_prior_weight = np.zeros(len(uniq_Zs))
    total_z_weight = np.zeros(len(uniq_Zs))

    for az, z_val in enumerate(uniq_Zs):
        print("computing the age-mass-metallicity grid weight for Z = ", z_val)

        # grid rows for a single metallicity and its unique ages
        z_mask = sel_Z == z_val
        zindxs = indxs[z_mask]
        uniq_ages = np.unique(sel_logA[z_mask])

        # compute the age weights
        age_grid_weights = compute_age_grid_weights(uniq_ages)
        age_prior_weights = compute_age_prior_weights(uniq_ages,
                                                      age_prior_model)

        for ak, age_val in enumerate(uniq_ages):
            # grid rows for a single (age, Z) combination
            aindxs = indxs[z_mask & (sel_logA == age_val)]

            # compute the mass weights
            if len(aindxs) > 1:
                cur_masses = all_M_ini[aindxs]
                mass_grid_weights = np.asarray(
                    compute_mass_grid_weights(cur_masses))
                mass_prior_weights = np.asarray(
                    compute_mass_prior_weights(cur_masses, mass_prior_model))
            else:
                # must be a single mass for this age,z combination
                # set mass weight to zero to remove this point from the grid
                mass_grid_weights = np.zeros(1)
                mass_prior_weights = np.zeros(1)

            # apply both the mass and age weights; index the column first so
            # the multiplication happens in place on _tgrid
            comb_grid_weights = mass_grid_weights * age_grid_weights[ak]
            comb_prior_weights = mass_prior_weights * age_prior_weights[ak]
            _tgrid["grid_weight"][aindxs] *= comb_grid_weights
            _tgrid["prior_weight"][aindxs] *= comb_prior_weights
            _tgrid["weight"][aindxs] *= comb_grid_weights * comb_prior_weights

        # compute the current total weight at each metallicity
        total_z_grid_weight[az] = np.sum(_tgrid["grid_weight"][zindxs])
        total_z_prior_weight[az] = np.sum(_tgrid["prior_weight"][zindxs])
        total_z_weight[az] = np.sum(_tgrid["weight"][zindxs])

    # ensure that the metallicity prior is uniform
    if len(uniq_Zs) > 1:
        # get the metallicity weights
        met_grid_weights = compute_metallicity_grid_weights(uniq_Zs)
        met_grid_weights /= np.sum(met_grid_weights)
        met_prior_weights = compute_metallicity_prior_weights(
            uniq_Zs, met_prior_model)
        met_prior_weights /= np.sum(met_prior_weights)
        met_weights = met_grid_weights * met_prior_weights

        # correct for any non-uniformity in the size of the age-mass grids
        # between metallicity points: divide by each metallicity's share of
        # the total weight (multiplying, as before, amplified the imbalance)
        grid_scale = _sum_to_scale(total_z_grid_weight)
        prior_scale = _sum_to_scale(total_z_prior_weight)
        weight_scale = _sum_to_scale(total_z_weight)

        for i, z_val in enumerate(uniq_Zs):
            # grid rows for this metallicity; column-first indexing so the
            # update modifies _tgrid (row-first indexing updated a copy)
            zindxs = indxs[sel_Z == z_val]
            _tgrid["grid_weight"][zindxs] *= (met_grid_weights[i] *
                                              grid_scale[i])
            _tgrid["prior_weight"][zindxs] *= (met_prior_weights[i] *
                                               prior_scale[i])
            _tgrid["weight"][zindxs] *= met_weights[i] * weight_scale[i]