Example #1
def test_weighted_quantile(seed=42):
    np.random.seed(seed)
    x = np.random.rand(25)
    q = np.arange(0.1, 1.0, 0.111234)
    a = corner.quantile(x, q, weights=np.ones_like(x))
    b = np.percentile(x, 100*np.array(q))
    assert np.allclose(a, b)

    q = [0.0, 1.0]
    a = corner.quantile(x, q, weights=np.random.rand(len(x)))
    assert np.allclose(a, (np.min(x), np.max(x)))
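The tests above pin down the behaviour expected of corner.quantile: uniform weights reproduce numpy.percentile, and q=0 / q=1 return the sample extremes. Below is a minimal sketch of a weighted quantile computed that way (sort, accumulate the weights into an empirical CDF, interpolate); it illustrates the idea and is not guaranteed to match corner's exact code:

import numpy as np

def weighted_quantile_sketch(x, q, weights):
    # Sort the samples, turn the cumulative weights into a CDF that runs
    # from 0 to 1, then interpolate the requested quantile levels.
    idx = np.argsort(x)
    sw = np.asarray(weights, dtype=float)[idx]
    cdf = np.cumsum(sw)[:-1]
    cdf = np.append(0.0, cdf / cdf[-1])
    return np.interp(q, cdf, np.asarray(x)[idx])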
Example #2
def test_weighted_quantile(seed=42):
    np.random.seed(seed)
    x = np.random.rand(25)
    q = np.arange(0.1, 1.0, 0.111234)
    a = corner.quantile(x, q, weights=np.ones_like(x))
    b = np.percentile(x, 100 * np.array(q))
    assert np.allclose(a, b)

    q = [0.0, 1.0]
    a = corner.quantile(x, q, weights=np.random.rand(len(x)))
    assert np.allclose(a, (np.min(x), np.max(x)))
Example #3
    def get_quantiles_width(self, param):
        quantiles = self.get_plotting_kwargs()['quantiles']
        s = self.posterior[[param]]
        original_qtles = corner.quantile(x=s, q=quantiles)
        self.reweighted_posterior = self.get_reweighted_posterior()
        rw = self.reweighted_posterior[[param]]
        reweighted_qtles = corner.quantile(x=rw, q=quantiles)
        return dict(
            original=original_qtles[1] - original_qtles[0],
            reweighted=reweighted_qtles[1] - reweighted_qtles[0],
        )
Example #4
def distill_factor_quantiles(correction_chain, pool=None):
    """Distill relativistic correction factor quantiles from a
    relativistic correction chain.

    Parameters
    ----------
    correction_chain : :class:`numpy.ndarray`
        Relativistic correction chain.
    pool : :class:`multiprocessing.Pool` or None, optional
        Multiprocessing pool (default is `None`).

    Returns
    -------
    factor_quantiles : dict of dict of :class:`numpy.ndarray`
        Relativistic correction factor quantiles.

    """
    mapping = pool.imap if pool else map
    num_cpus = cpu_count() if pool else 1

    quantile_levels = [0.022750, 0.158655, 0.5, 0.841345, 0.977250]
    factor_quantiles = {0: {}, 2: {}}

    logger.info(
        "Distilling relativistic correction factors at redshift %.2f " +
        "with %i CPUs...\n", progrc.redshift, num_cpus)

    factor_chain = np.asarray(
        list(
            tqdm(mapping(compute_factors_from_corrections, correction_chain),
                 total=len(correction_chain),
                 mininterval=1,
                 file=sys.stdout)))

    factor_0_q = np.asarray([
        corner.quantile(factor_chain[:, 0, k_idx], q=quantile_levels)
        for k_idx, k in enumerate(wavenumbers)
    ])
    factor_2_q = np.asarray([
        corner.quantile(factor_chain[:, 1, k_idx], q=quantile_levels)
        for k_idx, k in enumerate(wavenumbers)
    ])

    for q_idx, q in enumerate([-2, -1, 0, 1, 2]):
        factor_quantiles[0][q] = factor_0_q[:, q_idx]
        factor_quantiles[2][q] = factor_2_q[:, q_idx]

    logger.info("... finished.\n")

    return factor_quantiles
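The hard-coded quantile_levels above are the median plus the one- and two-sigma tails of a Gaussian, which is why the results are keyed by -2..+2 in the loop that follows. A quick check of that correspondence (assumes scipy is available):

from scipy.stats import norm

# CDF of a standard normal at -2, -1, 0, +1, +2 sigma
print(norm.cdf([-2, -1, 0, 1, 2]))
# [0.02275013 0.15865525 0.5        0.84134475 0.97724987]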
Example #5
def save_chains(shifted_chains):
    """Save shifted parameter sample chains.

    Parameters
    ----------
    shifted_chains : :class:`numpy.ndarray`
        Shifted chains.

    """
    infile = PATHOUT / progrc.chain_file
    outfile = PATHOUT / progrc.chain_file.replace(".h5", "shifted.h5")

    with hp.File(infile, 'r') as indata, hp.File(outfile, 'w') as outdata:
        outdata.create_group('mcmc')
        outdata.create_dataset('mcmc/chain', data=shifted_chains)
        outdata.create_dataset('mcmc/autocorr_time',
                               data=indata['mcmc/autocorr_time'][()])

    logger.info("Verify medians: {}.\n".format(
        np.squeeze([
            np.squeeze(
                corner.quantile(shifted_param_chain.reshape(
                    -1, len(PARAMETERS[progrc.model])),
                                q=[0.5]))
            for shifted_param_chain in np.transpose(shifted_chains)
        ])))

    logger.info("Shifted chain saved to %s.\n", outfile)
Example #6
    def calculate_dndf_arrays(self, comp, smin=0.01, smax=1000, nsteps=1000,
                              qs=[0.16, 0.5, 0.84]):
        """ Calculate dnds for specified quantiles
        """

        template = self.nptf.templates_dict_nested[comp]['template']
        template_masked_compressed = \
            self.mask_and_compress(template, self.mask_total)

        self.template_sum = np.sum(template_masked_compressed)

        self.sarray = 10**np.linspace(np.log10(smin), np.log10(smax), nsteps)

        self.flux_array = self.sarray/self.exp_masked_mean

        self.data_array = np.array([self.dnds(comp, sample, self.sarray)
                                   for sample in self.nptf.samples])

        # Rescaling factor to convert dN/dS to [(ph /cm^2 /s)^-2 /deg^2]
        # Note that self.area_mask has units deg^2.
        rf = self.template_sum*self.exp_masked_mean/self.area_mask

        self.qArray = [corner.quantile(self.data_array[:, i], qs)
                       for i in range(len(self.sarray))]

        self.qmean = rf*np.array([np.mean(self.data_array[:, i])
                                  for i in range(len(self.sarray))])

        self.qlow = rf*np.array([q[0] for q in self.qArray])
        self.qmid = rf*np.array([q[1] for q in self.qArray])
        self.qhigh = rf*np.array([q[2] for q in self.qArray])
Example #7
    def plot_intensity_fraction_non_poiss(self, comp, smin=0.00001, smax=1000,
                                          nsteps=1000, qs=[0.16, 0.5, 0.84],
                                          bins=50, color='blue',
                                          ls_vert='dashed', *args, **kwargs):
        """ Plot flux fraction of a non-Poissonian template

            :param bins: flux fraction bins
            :param color_vert: colour of vertical quartile lines
            :param ls_vert: matplotlib linestyle of vertical quartile lines
            **kwargs: plotting options
        """

        flux_fraction_array_non_poiss = \
            np.array(self.return_intensity_arrays_non_poiss(comp, smin=smin,
                     smax=smax, nsteps=nsteps, counts=True))/self.total_counts

        frac_hist_comp, bin_edges_comp = \
            np.histogram(100*np.array(flux_fraction_array_non_poiss), bins=bins,
                         range=(0, 100))

        qs_comp = \
            corner.quantile(100*np.array(flux_fraction_array_non_poiss), qs)

        plt.plot(bin_edges_comp[:-1],
                 frac_hist_comp/float(sum(frac_hist_comp)),
                 color=color, *args, **kwargs)

        for q in qs_comp:
            plt.axvline(q, ls=ls_vert, color=color)
        self.qs_comp = qs_comp
Example #8
def get_95_exclusion(samps, param_index, num=int(1e7), weights=None):
    # NOTE: sigma_p MUST BE THE LAST VARIABLE
    import corner
    onesig, twosig = corner.quantile(samps[:, param_index],
                                     [0.68, 0.95],
                                     weights=weights)
    return twosig
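Note that num and the 0.68 quantile are accepted/computed but unused; only the 0.95 value is returned, i.e. a one-sided 95% upper limit. A hypothetical call, assuming samps is a flattened chain of shape (n_samples, n_params) with sigma_p in the last column as the NOTE demands:

import numpy as np

samps = np.random.rand(5000, 3)  # toy chain; last column stands in for sigma_p
upper = get_95_exclusion(samps, param_index=-1)
print("95% upper bound:", upper)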
Example #9
    def plot_intensity_fraction_poiss(self,
                                      comp,
                                      qs=[0.16, 0.5, 0.84],
                                      bins=50,
                                      color='blue',
                                      ls_vert='dashed',
                                      *args,
                                      **kwargs):
        """ Plot flux fraction of non-Poissonian component
        """

        flux_fraction_array_poiss = \
            np.array(self.return_intensity_arrays_poiss(comp, counts=True))\
            / self.total_counts

        frac_hist_comp, bin_edges_comp = \
            np.histogram(100*np.array(flux_fraction_array_poiss), bins=bins,
                         range=(0, 100))

        qs_comp = corner.quantile(100 * np.array(flux_fraction_array_poiss),
                                  qs)

        plt.plot(bin_edges_comp[:-1],
                 frac_hist_comp / float(sum(frac_hist_comp)),
                 color=color,
                 *args,
                 **kwargs)

        for q in qs_comp:
            plt.axvline(q, ls=ls_vert, color=color)
        self.qs_comp = qs_comp
Example #10
def get_hdr(prob, q=.68, weights=None):
    # Highest-density region: keep the samples whose probability exceeds
    # the (1 - q) quantile of the probability values themselves.
    cond = prob > quantile(prob, q=1. - q, weights=weights)
    if any(cond):
        return cond, min(prob[cond])
    else:
        # No sample clears the threshold; fall back to the maximum.
        maximum = max(prob)
        cond = prob == maximum
        return cond, maximum
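A small usage sketch: for a unimodal set of posterior probabilities, get_hdr flags the samples inside the q=0.68 highest-density region and returns the probability threshold (this assumes quantile is corner.quantile, imported at module level as in the snippet):

import numpy as np
from corner import quantile

prob = np.exp(-0.5 * np.linspace(-4, 4, 1000) ** 2)  # toy unimodal density
cond, threshold = get_hdr(prob, q=0.68)
print(cond.sum(), "of", prob.size, "samples above threshold", threshold)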
Example #11
def test_valid_quantile(seed=42):
    np.random.seed(seed)
    x = np.random.rand(25)
    q = np.arange(0.1, 1.0, 0.111234)

    a = corner.quantile(x, q)
    b = np.percentile(x, 100*q)
    assert np.allclose(a, b)
Example #12
def test_valid_quantile(seed=42):
    np.random.seed(seed)
    x = np.random.rand(25)
    q = np.arange(0.1, 1.0, 0.111234)

    a = corner.quantile(x, q)
    b = np.percentile(x, 100 * q)
    assert np.allclose(a, b)
Example #13
def get_val(samples, weights=None):

    low, median, high = corner.quantile(samples, [0.16, 0.5, 0.84], weights)

    minus = median - low
    plus = high - median

    return (median, minus, plus)
Example #14
def get_cq_mass(result):

    trace = result['chain']
    thin = 5
    trace = trace[:, ::thin, :]

    samples = trace.reshape(trace.shape[0] * trace.shape[1], trace.shape[2])
    cq_mass = corner.quantile(x=samples[:, 0], q=[0.16, 0.5, 0.84])

    return cq_mass
Example #15
def percentile_2d(a, q, w=None, axis=None):

    if w is None:
        return np.percentile(a, q, axis=axis)

    if axis is None:
        return corner.quantile(a.flat, np.array(q) / 100.0, w)

    if axis == 0:
        ret = np.empty((len(q), a.shape[1]))
        for i in range(a.shape[1]):
            ret[:, i] = corner.quantile(a[:, i], np.array(q) / 100.0, w)
    elif axis == 1:
        ret = np.empty((len(q), a.shape[0]))
        for i in range(a.shape[0]):
            ret[:, i] = corner.quantile(a[i, :], np.array(q) / 100.0, w)
    else:
        raise ValueError("axis must be 0 or 1 for 2d array")

    return ret
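A quick consistency check: with uniform weights the corner-based path should match the plain numpy path (the uniform-weight case is exactly what the quantile tests above assert):

import numpy as np

a = np.random.rand(200, 4)
print(percentile_2d(a, [16, 50, 84], axis=0))                  # numpy path
print(percentile_2d(a, [16, 50, 84], w=np.ones(200), axis=0))  # corner path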
Example #16
def get_sfr(result, ms):

    trace = result['chain']
    thin = 5
    trace = trace[:, ::thin, :]

    samples = trace.reshape(trace.shape[0] * trace.shape[1], trace.shape[2])

    cq_age = corner.quantile(x=samples[:, 3], q=[0.16, 0.5, 0.84])
    cq_tau = corner.quantile(x=samples[:, 4], q=[0.16, 0.5, 0.84])

    age = cq_age[1] * 1e9  # Gyr to years
    logtau = cq_tau[1]
    tau = 10**logtau  # I think this is in years

    const = ms / (1 - np.exp(-1 * age / tau))
    prefac = const / tau

    sfr = prefac * np.exp(-1 * age / tau)

    return sfr
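The closed form above is the instantaneous SFR of an exponentially declining SFH, SFR(t) = A exp(-t/tau), with A normalised so that the mass formed by t = age equals ms. A quick numeric check of that normalisation (the parameter values are arbitrary):

import numpy as np
from scipy.integrate import quad

ms, age, tau = 1e10, 5e9, 2e9
A = ms / (tau * (1.0 - np.exp(-age / tau)))  # normalisation of SFR(t)
mass_formed, _ = quad(lambda t: A * np.exp(-t / tau), 0.0, age)
print(np.isclose(mass_formed, ms))  # True: the SFH integrates back to ms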
Example #17
def monteCarloTime(wdStar):
    """ Does a Monte-Carlo run of n=500 to create a corner plot.
    The images are saved as 'images/' + WhiteDwarfName + '_corner.png'
    """

    observed = wdStar.observedFluxes
    errors = wdStar.fluxErr
    checkLog = wdStar.logg
    arcsec = wdStar.parallax
    arcErr = wdStar.paraerr

    bergTable = []
    with open('bergeronFlux.csv', 'r') as f:
        for a in f:
            bergTable.append(list(map(float, a.split(', '))))
    teffR = []
    loggR = [7.0, 7.5, 8.0, 8.5, 9.0, 9.5]
    for a in bergTable:
        if not (a[0] in teffR):
            teffR.append(a[0])

    mrBerg = massRadiusTable()[0]
    massOrder = massRadiusTable()[1]

    nowThisisPodRacing = 500

    darthPlagueis = []
    start = time.time()
    for n in range(nowThisisPodRacing):
        print("On Monte-Carlo run: %i" % n)
        randCoeff = np.random.standard_normal((1, len(observed)))
        print("got the coeffs")
        randPara = np.random.standard_normal() * arcErr + arcsec
        randFluxMeasures = np.array(observed) + np.array(errors) * randCoeff
        darthPlagueis.append(
            monteGame(bergTable, mrBerg, massOrder, teffR, loggR,
                      randFluxMeasures, errors, wdStar.labels, arcsec) +
            [1 / randPara] + randFluxMeasures.tolist()[0])
    end = time.time()
    print("%i runs took %f seconds" % (nowThisisPodRacing, end - start))
    print("Building corner plot . . .")
    cornerData = darthPlagueis
    cornerLabel = ["Teff", "Log(g)", "Mass", "Radius", "Distance"
                   ] + wdStar.labels
    midichlorians = corner.corner(cornerData,
                                  labels=cornerLabel,
                                  show_titles=True,
                                  use_math_text=True,
                                  range=[1.0 for a in range(len(cornerLabel))])
    maceWindu = corner.quantile([a[3] for a in cornerData], [.25, .5, .75])
    print("Quantiles?: ")
    print(maceWindu)
    midichlorians.savefig('images/' + wdStar.name + '_corner.png')
Example #18
def summarise_samples(
        samples: np.ndarray,
        quantiles: Optional[List[float]] = [0.16, 0.84]
):
    """Converts samples to lower, upper and median values
    :return: lower, upper, median
    """
    qtles = corner.quantile(x=samples, q=quantiles)
    return (
        qtles[0],
        qtles[1],
        np.median(samples)
    )
Example #19
def print_output_means(samples):
    """ Print the mean values obtained with the samples
    CURRENTLY ONLY WORKS FOR MODEL = 'AERI' OPTION

    Usage:
    print_output_means(samples)
    """

    if flag.model == 'aeri':
        lista = ['Mass ', 'W ', 't/tms', 'i']
        print(60 * '-')
        print('Output - mean values')
        for i in range(4):
            param = samples[:, i]
            param_mean = corner.quantile(param, 0.5, weights=None)
            param_sup = corner.quantile(param, 0.84, weights=None)
            param_inf = corner.quantile(param, 0.16, weights=None)
            # Note: param_sup and param_inf are the 84th/16th percentile
            # values themselves, not offsets from the median.
            print(lista[i], param_mean, ' +/- ', param_sup, param_inf)
            if i == 0:
                mass = param_mean
            if i == 1:
                oblat = 1 + 0.5 * (param_mean**2)  # Rimulo 2017
                oblat_sup = 1 + 0.5 * (param_sup**2)
                oblat_inf = 1 + 0.5 * (param_inf**2)
                print('Derived oblateness ', oblat, ' +/- ', oblat_sup,
                      oblat_inf)
            if i == 2:
                tms = param_mean

        Rpole, logL, _ = geneva_interp_fast(mass, oblat, tms, Zstr='014')
        beta_par = beta(oblat, is_ob=True)

        print('Equatorial radius: ', oblat * Rpole)
        print('Log Luminosity: ', logL)
        print('Beta: ', beta_par)

    return
Example #20
def result_string(x, weights=None, title_fmt=".2f", label=None):
    """
    :param x: marginalized 1-d posterior
    :param weights: weights of posteriors (optional)
    :param title_fmt: format to what digit the results are presented
    :param label: string of parameter label (optional)
    :return: string with the median and the +/- offsets to the 84th/16th percentiles
    """
    q_16, q_50, q_84 = quantile(x, [0.16, 0.5, 0.84], weights=weights)
    q_m, q_p = q_50 - q_16, q_84 - q_50

    # Format the quantile display.
    fmt = "{{0:{0}}}".format(title_fmt).format
    title = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
    title = title.format(fmt(q_50), fmt(q_m), fmt(q_p))
    if label is not None:
        title = "{0} = {1}".format(label, title)
    return title
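Example output, assuming x is a 1-d posterior sample and that quantile is corner.quantile imported at module level (the LaTeX string is intended for matplotlib titles):

import numpy as np

x = np.random.normal(2.0, 0.3, size=20000)
print(result_string(x, title_fmt=".2f", label="m"))
# m = ${2.00}_{-0.30}^{+0.30}$  (up to sampling noise)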
Example #21
def test_dimension_mismatch(seed=42):
    np.random.seed(seed)
    # len(weights) != len(x): corner.quantile raises a ValueError here,
    # which the original test suite presumably asserts
    corner.quantile(np.random.rand(100), [0.1, 0.5],
                    weights=np.random.rand(3))
Example #22
def test_invalid_quantiles_3(seed=42):
    np.random.seed(seed)
    # quantile levels outside [0, 1] are invalid; this raises ValueError
    corner.quantile(np.random.rand(100), [0.5, 1.0, 8.1])
Example #23
def collate_data(alldata,alldata_noagn):

    #### generate containers
    # photometry
    obs_phot, model_phot = {}, {}
    filters = ['wise_w1', 'wise_w2', 'wise_w3', 'wise_w4',
               'spitzer_irac_ch1','spitzer_irac_ch2','spitzer_irac_ch3','spitzer_irac_ch4']

    for f in filters: 
        obs_phot[f] = []
        model_phot[f] = []

    # model parameters
    z,objname = [], []
    model_pars = {}
    pnames = ['fagn', 'agn_tau']
    for p in pnames: 
        model_pars[p] = {'q50':[],'q84':[],'q16':[]}
    parnames = alldata[0]['pquantiles']['parnames']

    #### load information
    for dat in alldata:
        objname.append(dat['objname'])

        #### model parameters
        for key in model_pars.keys():
            model_pars[key]['q50'].append(dat['pquantiles']['q50'][parnames==key][0])
            model_pars[key]['q84'].append(dat['pquantiles']['q84'][parnames==key][0])
            model_pars[key]['q16'].append(dat['pquantiles']['q16'][parnames==key][0])

    #### X-ray information
    # match
    eparnames = alldata[0]['pextras']['parnames']
    xr_idx = eparnames == 'xray_lum'
    xray = prospector_io.load_xray_cat(xmatch = True)

    lsfr, lsfr_up, lsfr_down, xray_lum, xray_lum_err = [], [], [], [], []
    for i, dat in enumerate(alldata):
        idx = xray['objname'] == dat['objname']
        if idx.sum() != 1:
            raise ValueError('no unique X-ray match for ' + dat['objname'])
        xflux = xray['flux'][idx][0]
        xflux_err = xray['flux_err'][idx][0]

        #### convert lumdist to redshift for distance calculations
        # only in alldata_noagn for now...
        lumdist = alldata_noagn[i]['residuals']['phot']['lumdist']
        zred = brentq(test_z, 0, 0.2, args=(lumdist,), rtol=1.48e-08, maxiter=1000)
        z.append(zred)

        # flux is in ergs / cm^2 / s, convert to erg /s 
        pc2cm =  3.08568E18
        dfactor = 4*np.pi*(lumdist*1e6*pc2cm)**2
        xray_lum.append(xflux * dfactor)
        xray_lum_err.append(xflux_err * dfactor)

        ##### L_OBS / L_SFR(MODEL)
        # sample from the chain, assume gaussian errors for x-ray fluxes
        nsamp = 10000
        chain = dat['pextras']['flatchain'][:,xr_idx].squeeze()
        scale = xray_lum_err[-1]
        if scale <= 0:
            subchain =  np.repeat(xray_lum[-1], nsamp) / \
                        np.random.choice(chain,nsamp)
        else:
            subchain =  np.random.normal(loc=xray_lum[-1], scale=scale, size=nsamp) / \
                        np.random.choice(chain,nsamp)

        cent, eup, edo = quantile(subchain, [0.5, 0.84, 0.16])

        lsfr.append(cent)
        lsfr_up.append(eup)
        lsfr_down.append(edo)

    #### numpy arrays
    for key in model_pars.keys(): 
        for key2 in model_pars[key].keys():
            model_pars[key][key2] = np.array(model_pars[key][key2])

    out = {'pars':model_pars,'objname':objname}

    out['z'] = np.array(z)
    out['lsfr'] = np.array(lsfr)
    out['lsfr_up'] = np.array(lsfr_up)
    out['lsfr_down'] = np.array(lsfr_down)
    out['xray_luminosity'] = np.array(xray_lum)
    out['xray_luminosity_err'] = np.array(xray_lum_err)
    out['bpt'] = prospector_io.return_agn_str(np.ones_like(alldata,dtype=bool),string=True)

    return out
Example #24
        # Show samples in a triangle or corner plot
        theta_truth = np.array([run_params[i] 
                                for i in ['mass','logzsol','tau','tage','dust2']])
        theta_truth[0] = np.log10(theta_truth[0])


        fc = res['chain'][:,0::5,:]
        midx = [l=='mass' for l in varnames]
        #fc[:,midx] = np.log10(fc[:,midx])
        fc = np.transpose(fc)
        print(obs['objid'][i])

        for p00p, things in enumerate(fc):
            q_16, q_50, q_84 = quantile(things, [0.16, 0.5, 0.84],)
            q_m, q_p = q_50-q_16, q_84-q_50

            best = q_50
            low = q_m
            high = q_p
            table['low ' + varnames[p00p]][i] = low
            table['best ' + varnames[p00p]][i] = best
            table['high ' + varnames[p00p]][i] = high

            ws.write(i+1, p00p*3+1, low)
            ws.write(i+1, p00p*3+2, best)
            ws.write(i+1, p00p*3+3, high)


Example #25
def test_dimension_mismatch(seed=42):
    np.random.seed(seed)
    corner.quantile(np.random.rand(100), [0.1, 0.5], weights=np.random.rand(3))
Example #26
def test_invalid_quantiles_1(seed=42):
    np.random.seed(seed)
    corner.quantile(np.random.rand(100), [-0.1, 5])
Example #27
def test_invalid_quantiles_2(seed=42):
    np.random.seed(seed)
    corner.quantile(np.random.rand(100), 5)
Example #28
def MCMC_sample(walkers=100,
                steps=50000,
                lambS=lambS,
                lambA=lambA,
                lamb1=lamb1,
                d1=d1,
                mu1=mu1,
                diagnostics=False,
                storechain=True):
    """ Initiates MCMC sampling from the posterior defined above using given
    number of walkers for given number of steps. Returns sampler object which
    carries the resulting parameters for all walkers at all steps. """
    # get starting positions for each walker
    theta = [lambS, lambA, lamb1, d1, mu1]
    ndim = len(theta)
    startpos = [theta + 1e-1 * np.random.rand(ndim) for i in range(walkers)]
    burnin = 5000
    # set up sampler and run from the given starting positions
    sampler = emcee.EnsembleSampler(walkers, ndim, lnpost, threads=4)
    sampler.run_mcmc(startpos, steps)

    if diagnostics:  # make MCMC diagnostics plot
        try:
            autocorrest = np.round(sampler.acor, 2)
        except Exception:
            autocorrest = [np.nan, np.nan, np.nan, np.nan, np.nan]
        # plot simple diagnostics of a subset of walkers - time series and
        # autocorrelation for all parameters
        diagfig, axes = plt.subplots(len(theta), 2, figsize=(10, 17))
        titles = ['lambS', 'lambA', 'lamb1', 'd1', 'mu1']
        # get a random subset of walkers for plotting, say 15
        plotinds = np.random.choice(range(walkers), size=15, replace=False)
        diagfig.suptitle(
            """walkers = {}, steps = {}, autocorrelation times = {},
                            mean acceptance fraction = {}""".format(
                walkers, steps, autocorrest,
                np.mean(sampler.acceptance_fraction)))
        for f in range(len(theta)):
            for p in range(len(plotinds)):
                timeseries = sampler.chain[plotinds[p], :, f]
                axes[f][0].set_title('time series {}'.format(titles[f]))
                axes[f][0].plot(timeseries)
                axes[len(theta) - 1][0].set_xlabel('index')

                axes[f][1].set_title('autocorrelation {}'.format(titles[f]))
                maxlag = min(steps, 200)
                axes[f][1].plot(range(maxlag), autocorr(timeseries, maxlag))
                axes[len(theta) - 1][1].set_xlabel('lag')
        pylab.savefig('figures/diagnostics_{}_{}_C.pdf'.format(
            walkers, (steps)))

    # find n-dim MAP and 1D credible regions at 5% level
    lambS_MAP, lambA_MAP, lamb1_MAP, d1_MAP, mu1_MAP = \
        find_MAP(sampler.flatchain, sampler.flatlnprobability)
    lambS_5min, lambS_5max = find_CR(sampler.flatchain[:, 0],
                                     sampler.flatlnprobability)
    lambA_5min, lambA_5max = find_CR(sampler.flatchain[:, 1],
                                     sampler.flatlnprobability)
    lamb1_5min, lamb1_5max = find_CR(sampler.flatchain[:, 2],
                                     sampler.flatlnprobability)
    d1_5min, d1_5max = find_CR(sampler.flatchain[:, 3],
                               sampler.flatlnprobability)
    mu1_5min, mu1_5max = find_CR(sampler.flatchain[:, 4],
                                 sampler.flatlnprobability)

    # plot corner plot for orientation
    samples = sampler.chain[:, burnin:, :].reshape((-1, ndim))
    fig = corner.corner(samples,
                        labels=[
                            r'$\lambda_S$ (1/day)', r'$\lambda_A$ (1/day)',
                            r'$\lambda_1$ (1/day)', r'$d_1$ (1/day)',
                            r'$\mu_1$ (1/day)'
                        ],
                        truths=[[lambS_MAP], [lambA_MAP], [lamb1_MAP],
                                [d1_MAP], [mu1_MAP]],
                        truth_color=['firebrick'],
                        hist_kwargs={
                            'color': 'darkgrey',
                            'histtype': 'stepfilled'
                        })
    fig.savefig("figures/MCMC_{}_{}.pdf".format(walkers, steps))
    plt.close()

    # print stuff to file
    datafile = open('figures/results_{}_{}.txt'.format(walkers, steps), 'w')

    datafile.write('1) n-dimensional MAP values \n \n')
    datafile.write('{} \n{} \n{} \n{} \n{} \n \n'.format(
        lambS_MAP, lambA_MAP, lamb1_MAP, d1_MAP, mu1_MAP))

    datafile.write('2) 0.05 credibility boundaries in 1-dim \n \n')
    datafile.write('{}, {} \n'.format(lambS_5min, lambS_5max))
    datafile.write('{}, {} \n'.format(lambA_5min, lambA_5max))
    datafile.write('{}, {} \n'.format(lamb1_5min, lamb1_5max))
    datafile.write('{}, {} \n'.format(d1_5min, d1_5max))
    datafile.write('{}, {} \n \n'.format(mu1_5min, mu1_5max))

    datafile.write('3) quantiles: (2.5, 16, 50, 84, 97.5)\n')
    datafile.write('{}\n'.format(
        corner.quantile(samples[:, 0], [0.025, 0.16, 0.5, 0.84, 0.975])))
    datafile.write('{}\n'.format(
        corner.quantile(samples[:, 1], [0.025, 0.16, 0.5, 0.84, 0.975])))
    datafile.write('{}\n'.format(
        corner.quantile(samples[:, 2], [0.025, 0.16, 0.5, 0.84, 0.975])))
    datafile.write('{}\n'.format(
        corner.quantile(samples[:, 3], [0.025, 0.16, 0.5, 0.84, 0.975])))
    datafile.write('{}\n'.format(
        corner.quantile(samples[:, 4], [0.025, 0.16, 0.5, 0.84, 0.975])))
    datafile.close()

    # plot profile posteriors
    lnprobs = sampler.flatlnprobability
    sampled_params = sampler.flatchain
    best_post = np.max(lnprobs)

    fig, axes = plt.subplots(1, 5, sharey=True, figsize=(15, 3))

    axes[0].plot(sampled_params[:, 0],
                 -(lnprobs - best_post),
                 'o',
                 color='darkgrey',
                 alpha=0.1,
                 rasterized=True)
    axes[0].set_title(r'$\lambda_S = {}_{{\ {}}}^{{\ {}}}$ 1/day'.format(
        r2(lambS_MAP), r2(lambS_5min), r2(lambS_5max)))
    axes[0].axhline(y=3, xmin=0., xmax=1, color='firebrick')
    axes[0].set_ylim([-0.2, 5])
    axes[0].set_ylabel(r'log $p(\theta | \theta_{{MAP}})$')

    axes[1].plot(sampled_params[:, 1],
                 -(lnprobs - best_post),
                 'o',
                 color='darkgrey',
                 alpha=0.1,
                 rasterized=True)
    axes[1].set_title(r'$\lambda_A = {}_{{\ {}}}^{{\ {}}}$ 1/day'.format(
        r2(lambA_MAP), r2(lambA_5min), r2(lambA_5max)))
    axes[1].set_xlim([0, 2])
    axes[1].axhline(y=3, xmin=0., xmax=1, color='firebrick')

    axes[2].plot(sampled_params[:, 2],
                 -(lnprobs - best_post),
                 'o',
                 color='darkgrey',
                 alpha=0.1,
                 rasterized=True)
    axes[2].set_title(r'$\lambda_1 = {}_{{\ {}}}^{{\ {}}}$ 1/day'.format(
        r2(lamb1_MAP), r2(lamb1_5min), r2(lamb1_5max)))
    axes[2].set_xlabel('rates (1/day)')
    axes[2].axhline(y=3, xmin=0., xmax=1, color='firebrick')
    axes[2].set_xlim([-0.05, 2.2])

    axes[3].plot(sampled_params[:, 3],
                 -(lnprobs - best_post),
                 'o',
                 color='darkgrey',
                 alpha=0.1,
                 rasterized=True)
    axes[3].set_title(r'$d_1 = {}_{{\ {}}}^{{\ {}}}$ 1/day'.format(
        r2(d1_MAP), r2(d1_5min), r2(d1_5max)))
    axes[3].axhline(y=3, xmin=0., xmax=1, color='firebrick')
    axes[3].set_xlim([-0.05, 2.2])

    axes[4].plot(sampled_params[:, 4],
                 -(lnprobs - best_post),
                 'o',
                 color='darkgrey',
                 alpha=0.1,
                 rasterized=True)
    axes[4].set_title(r'$\mu_1 = {}_{{\ {}}}^{{\ {}}}$ 1/day'.format(
        r2(mu1_MAP), r2(mu1_5min), r2(mu1_5max)))
    axes[4].set_xlim([-0.05, 0.7])
    axes[4].axhline(y=3, xmin=0., xmax=1, color='firebrick')

    pylab.savefig('figures/post_profiles_{}_{}.pdf'.format(walkers, steps),
                  bbox_inches='tight')
    plt.close()

    if storechain:  # writes the whole chain (if the chain is short) or a
        # randomly subsampled set of size Nchain to a dataframe and then to a
        # .h5 file. Separately saves the best found values (MLE).
        Nchain = 20000
        burned_lnP = sampler.flatlnprobability[burnin:]
        burned_params = sampler.flatchain[burnin:]
        chainlength = len(burned_lnP)
        if chainlength > Nchain:
            ch_indxs = np.random.choice(np.arange(chainlength),
                                        Nchain,
                                        replace=False)
            ch_probs = burned_lnP[ch_indxs]
            ch_params = burned_params[ch_indxs]
            print('long chain subsampled')
        else:
            ch_probs = burned_lnP
            ch_params = burned_params
            print('whole (short) chain stored')

        # create dataframe from this
        df = pd.DataFrame({
            'lnPosterior': ch_probs,
            'lambS': [x[0] for x in ch_params],
            'lambA': [x[1] for x in ch_params],
            'lamb1': [x[2] for x in ch_params],
            'd1': [x[3] for x in ch_params],
            'mu1': [x[4] for x in ch_params]
        })

        # mini dataframe for the MLE values
        dfMLE = pd.DataFrame({
            'lnPosterior': [best_post],
            'lambS': [lambS_MAP],
            'lambA': [lambA_MAP],
            'lamb1': [lamb1_MAP],
            'd1': [d1_MAP],
            'mu1': [mu1_MAP]
        })

        # mini dataframe for the upper 0.05 bounds
        dfUPPER = pd.DataFrame({
            'lambS': [lambS_5max],
            'lambA': [lambA_5max],
            'lamb1': [lamb1_5max],
            'd1': [d1_5max],
            'mu1': [mu1_5max]
        })

        # mini dataframe for the lower 0.05 bounds
        dfLOWER = pd.DataFrame({
            'lambS': [lambS_5min],
            'lambA': [lambA_5min],
            'lamb1': [lamb1_5min],
            'd1': [d1_5min],
            'mu1': [mu1_5min]
        })

        df.to_hdf('figures/chain_{}_{}.h5'.format(walkers, steps),
                  key='MCMC',
                  mode='w')
        dfMLE.to_hdf('figures/chain_{}_{}.h5'.format(walkers, steps),
                     key='MLE',
                     mode='r+')
        dfUPPER.to_hdf('figures/chain_{}_{}.h5'.format(walkers, steps),
                       key='UPPER',
                       mode='r+')
        dfLOWER.to_hdf('figures/chain_{}_{}.h5'.format(walkers, steps),
                       key='LOWER',
                       mode='r+')

    return sampler
Example #29
def main(field, galaxy_seq):

    #vers = (np.__version__, scipy.__version__, h5py.__version__, fsps.__version__, prospect.__version__)
    #print("Numpy: {}\nScipy: {}\nH5PY: {}\nFSPS: {}\nProspect: {}".format(*vers))

    # -------------- Decide field and filters
    # Read in catalog from Lou
    if 'North' in field:
        df = pandas.read_pickle(adap_dir + 'GOODS_North_SNeIa_host_phot.pkl')

        all_filters = [
            'LBC_U_FLUX', 'ACS_F435W_FLUX', 'ACS_F606W_FLUX', 'ACS_F775W_FLUX',
            'ACS_F814W_FLUX', 'ACS_F850LP_FLUX', 'WFC3_F105W_FLUX',
            'WFC3_F125W_FLUX', 'WFC3_F140W_FLUX', 'WFC3_F160W_FLUX',
            'MOIRCS_K_FLUX', 'CFHT_Ks_FLUX', 'IRAC_CH1_SCANDELS_FLUX',
            'IRAC_CH2_SCANDELS_FLUX', 'IRAC_CH3_FLUX', 'IRAC_CH4_FLUX'
        ]

        #all_filters = ['LBC_U_FLUX', 'ACS_F435W_FLUX', 'ACS_F606W_FLUX',
        #'ACS_F775W_FLUX', 'ACS_F850LP_FLUX']

        seq = np.array(df['ID'])
        i = int(np.where(seq == galaxy_seq)[0])

    elif 'South' in field:
        df = pandas.read_pickle(adap_dir + 'GOODS_South_SNeIa_host_phot.pkl')

        all_filters = [
            'CTIO_U_FLUX', 'ACS_F435W_FLUX', 'ACS_F606W_FLUX',
            'ACS_F775W_FLUX', 'ACS_F814W_FLUX', 'ACS_F850LP_FLUX',
            'WFC3_F098M_FLUX', 'WFC3_F105W_FLUX', 'WFC3_F125W_FLUX',
            'WFC3_F160W_FLUX', 'HAWKI_KS_FLUX', 'IRAC_CH1_FLUX',
            'IRAC_CH2_FLUX', 'IRAC_CH3_FLUX', 'IRAC_CH4_FLUX'
        ]

        #all_filters = ['CTIO_U_FLUX', 'ACS_F435W_FLUX', 'ACS_F606W_FLUX',
        #'ACS_F775W_FLUX', 'ACS_F850LP_FLUX']

        seq = np.array(df['Seq'])
        i = int(np.where(seq == galaxy_seq)[0])

    #print('Read in pickle with the following columns:')
    #print(df.columns)
    #print('Rows in DataFrame:', len(df))

    print("Match index:", i, "for Seq:", galaxy_seq)

    # -------------- Preliminary stuff
    # Set up for emcee
    nwalkers = 1000
    niter = 500

    ndim = 12

    # Other set up
    obj_ra = df['RA'][i]
    obj_dec = df['DEC'][i]

    obj_z = df['zbest'][i]

    print("Object redshift:", obj_z)
    age_at_z = astropy_cosmo.age(obj_z).value
    print("Age of Universe at object redshift [Gyr]:", age_at_z)

    # ------------- Get obs data
    fluxes = []
    fluxes_unc = []
    useable_filters = []

    for ft in range(len(all_filters)):
        filter_name = all_filters[ft]

        flux = df[filter_name][i]
        fluxerr = df[filter_name + 'ERR'][i]

        if np.isnan(flux):
            continue

        if flux <= 0.0:
            continue

        if (fluxerr < 0) or np.isnan(fluxerr):
            fluxerr = 0.1 * flux

        fluxes.append(flux)
        fluxes_unc.append(fluxerr)
        useable_filters.append(filter_name)

    #print("\n")
    #print(df.loc[i])
    #print(fluxes, len(fluxes))
    #print(useable_filters, len(useable_filters))

    fluxes = np.array(fluxes)
    fluxes_unc = np.array(fluxes_unc)

    # Now build the prospector observation
    obs = build_obs(fluxes, fluxes_unc, useable_filters)

    # Set params for run
    run_params = {}
    run_params["object_redshift"] = obj_z
    run_params["fixed_metallicity"] = None
    run_params["add_duste"] = True
    #run_params["dust_type"] = 4

    run_params["zcontinuous"] = 1

    # Generate the model SED at the initial value of theta
    #theta = model.theta.copy()
    #initial_spec, initial_phot, initial_mfrac = model.sed(theta, obs=obs, sps=sps)

    verbose = True
    run_params["verbose"] = verbose

    model = build_model(**run_params)
    print("\nInitial free parameter vector theta:\n  {}\n".format(model.theta))
    #print("Initial parameter dictionary:\n{}".format(model.params))
    print("\n----------------------- Model details: -----------------------")
    print(model)
    print(
        "----------------------- End model details. -----------------------\n")

    # Here we will run all our building functions
    obs = build_obs(fluxes, fluxes_unc, useable_filters)
    sps = build_sps(**run_params)

    #plot_data(obs)
    #sys.exit(0)

    # --- start fitting ----
    # Set this to False if you don't want to do another optimization
    # before emcee sampling (but note that the "optimization" entry
    # in the output dictionary will be (None, 0.) in this case)
    # If set to true then another round of optmization will be performed
    # before sampling begins and the "optmization" entry of the output
    # will be populated.
    """
    run_params["optimize"] = False
    run_params["min_method"] = 'lm'
    # We'll start minimization from "nmin" separate places, 
    # the first based on the current values of each parameter and the 
    # rest drawn from the prior.  Starting from these extra draws 
    # can guard against local minima, or problems caused by 
    # starting at the edge of a prior (e.g. dust2=0.0)
    run_params["nmin"] = 5

    run_params["emcee"] = True
    run_params["dynesty"] = False
    # Number of emcee walkers
    run_params["nwalkers"] = nwalkers
    # Number of iterations of the MCMC sampling
    run_params["niter"] = niter
    # Number of iterations in each round of burn-in
    # After each round, the walkers are reinitialized based on the 
    # locations of the highest probablity half of the walkers.
    run_params["nburn"] = [8, 16, 32, 64]
    run_params["progress"] = True

    hfile = adap_dir + "emcee_" + field + "_" + str(galaxy_seq) + ".h5"

    if not os.path.isfile(hfile):

        print("Now running with Emcee.")
        output = fit_model(obs, model, sps, lnprobfn=lnprobfn, **run_params)

        print('done emcee in {0}s'.format(output["sampling"][1]))
    
        writer.write_hdf5(hfile, run_params, model, obs,
                          output["sampling"][0], output["optimization"][0],
                          tsample=output["sampling"][1],
                          toptimize=output["optimization"][1])
    
        print('Finished with Seq: ' + str(galaxy_seq))

    """

    hfile = adap_dir + "dynesty_" + field + "_" + str(galaxy_seq) + ".h5"

    if not os.path.isfile(hfile):

        print("Now running with Dynesty.")

        run_params["emcee"] = False
        run_params["dynesty"] = True
        run_params["nested_method"] = "rwalk"
        run_params["nlive_init"] = 400
        run_params["nlive_batch"] = 200
        run_params["nested_dlogz_init"] = 0.05
        run_params["nested_posterior_thresh"] = 0.05
        run_params["nested_maxcall"] = int(1e6)

        #from multiprocessing import Pool
        #with Pool(6) as pool:
        #    run_params["pool"] = pool
        output = fit_model(obs, model, sps, lnprobfn=lnprobfn, **run_params)
        print('done dynesty in {0}s'.format(output["sampling"][1]))

        writer.write_hdf5(hfile,
                          run_params,
                          model,
                          obs,
                          output["sampling"][0],
                          output["optimization"][0],
                          tsample=output["sampling"][1],
                          toptimize=output["optimization"][1])

        print('Finished with Seq: ' + str(galaxy_seq))

    # -------------------------
    # Visualizing results

    results_type = "dynesty"  # "emcee" | "dynesty"
    # grab results (dictionary), the obs dictionary, and our corresponding models
    # When using parameter files set `dangerous=True`
    #result, obs, _ = reader.results_from("{}_" + str(galaxy_seq) + \
    #                 ".h5".format(results_type), dangerous=False)

    result, obs, _ = reader.results_from(adap_dir + results_type + "_" + \
                     field + "_" + str(galaxy_seq) + ".h5", dangerous=False)

    #The following commented lines reconstruct the model and sps object,
    # if a parameter file containing the `build_*` methods was saved along with the results
    #model = reader.get_model(result)
    #sps = reader.get_sps(result)

    # let's look at what's stored in the `result` dictionary
    print(result.keys())

    parnames = np.array(result['theta_labels'])
    print('Parameters in this model:', parnames)
    """
    if results_type == "emcee":

        chosen = np.random.choice(result["run_params"]["nwalkers"], size=150, replace=False)
        tracefig = reader.traceplot(result, figsize=(10,6), chains=chosen)

        tracefig.savefig(adap_dir + 'trace_' + field + '_' + str(galaxy_seq) + '.pdf', 
            dpi=200, bbox_inches='tight')

    else:
        tracefig = reader.traceplot(result, figsize=(10,6))
        tracefig.savefig(adap_dir + 'trace_' + field + '_' + str(galaxy_seq) + '.pdf', 
            dpi=200, bbox_inches='tight')
    """

    # Get chain for corner plot
    if results_type == 'emcee':
        trace = result['chain']
        thin = 5
        trace = trace[:, ::thin, :]
        samples = trace.reshape(trace.shape[0] * trace.shape[1],
                                trace.shape[2])
    else:
        samples = result['chain']

    # Plot SFH
    #print(model) # to copy paste agebins from print out

    nagebins = 6
    agebins = np.array([[0., 8.], [8., 8.47712125], [8.47712125, 9.],
                        [9., 9.47712125], [9.47712125, 9.77815125],
                        [9.77815125, 10.13353891]])

    # Get the zfractions from corner quantiles
    zf1 = corner.quantile(samples[:, 2], q=[0.16, 0.5, 0.84])
    zf2 = corner.quantile(samples[:, 3], q=[0.16, 0.5, 0.84])
    zf3 = corner.quantile(samples[:, 4], q=[0.16, 0.5, 0.84])
    zf4 = corner.quantile(samples[:, 5], q=[0.16, 0.5, 0.84])
    zf5 = corner.quantile(samples[:, 6], q=[0.16, 0.5, 0.84])

    zf_arr = np.array([zf1[1], zf2[1], zf3[1], zf4[1], zf5[1]])
    zf_arr_l = np.array([zf1[0], zf2[0], zf3[0], zf4[0], zf5[0]])
    zf_arr_u = np.array([zf1[2], zf2[2], zf3[2], zf4[2], zf5[2]])

    cq_mass = corner.quantile(samples[:, 7], q=[0.16, 0.5, 0.84])

    print("Total mass:", "{:.3e}".format(cq_mass[1]), "+",
          "{:.3e}".format(cq_mass[2] - cq_mass[1]), "-",
          "{:.3e}".format(cq_mass[1] - cq_mass[0]))
    # -----------

    new_agebins = pt.zred_to_agebins(zred=obj_z, agebins=agebins)
    print("New agebins:", new_agebins)

    # -----------------------
    # now convert to sfh and its errors
    sfr = pt.zfrac_to_sfr(total_mass=cq_mass[1],
                          z_fraction=zf_arr,
                          agebins=new_agebins)
    sfr_l = pt.zfrac_to_sfr(total_mass=cq_mass[1],
                            z_fraction=zf_arr_l,
                            agebins=new_agebins)
    sfr_u = pt.zfrac_to_sfr(total_mass=cq_mass[1],
                            z_fraction=zf_arr_u,
                            agebins=new_agebins)

    print("----------")
    print("z fractions:      ", zf_arr)
    print("Lower z fractions:", zf_arr_l)
    print("Upper z fractions:", zf_arr_u)

    print("----------")
    print("Inferred SFR:  ", sfr)
    print("Lower sfr vals:", sfr_l)
    print("Upper sfr vals:", sfr_u)

    #############
    x_agebins = 10**new_agebins / 1e9

    fig = plt.figure(figsize=(8, 4))
    ax = fig.add_subplot(111)

    ax.set_xlabel(r'$\mathrm{Time\, [Gyr];\, since\ galaxy\ formation}$',
                  fontsize=20)
    ax.set_ylabel(r'$\mathrm{SFR\, [M_\odot/yr]}$', fontsize=20)

    for a in range(len(agebins)):
        ax.plot(x_agebins[a],
                np.ones(len(x_agebins[a])) * sfr[a],
                color='mediumblue',
                lw=3.0)
        # put in some poisson errors
        sfr_err = np.ones(len(x_agebins[a])) * np.sqrt(sfr[a])
        sfr_plt = np.ones(len(x_agebins[a])) * sfr[a]
        sfr_low_fill = sfr_plt - sfr_err
        sfr_up_fill = sfr_plt + sfr_err
        ax.fill_between(x_agebins[a],
                        sfr_low_fill,
                        sfr_up_fill,
                        color='gray',
                        alpha=0.6)

    #ax.set_ylim(np.min(sfr_low_fill) * 0.3, np.max(sfr_up_fill) * 1.1)

    fig.savefig(adap_dir + 'sfh_' + field + '_' + str(galaxy_seq) + '.pdf',
                dpi=200,
                bbox_inches='tight')

    sys.exit(0)

    # Keep code block for future use if needed
    """
    # combination of linear and log axis from
    # https://stackoverflow.com/questions/21746491/combining-a-log-and-linear-scale-in-matplotlib

    # linear part i.e., first age bin
    ax.plot(x_agebins[0], np.ones(len(x_agebins[0])) * sfr[0], color='mediumblue', lw=3.5)
    #ax.fill_between(x_agebins[0], np.ones(len(x_agebins[0])) * sfr_l[0], 
    #               np.ones(len(x_agebins[0])) * sfr_u[0], color='gray', alpha=0.5)

    ax.set_xlim(0.0, 8.0)
    ax.spines['right'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_label_coords(x=1.2,y=-0.06)

    # now log axis 
    divider = make_axes_locatable(ax)
    axlog = divider.append_axes("right", size=3.0, pad=0, sharey=ax)
    axlog.set_xscale('log')

    for a in range(1, nagebins):
        axlog.plot(x_agebins[a], np.ones(len(x_agebins[a])) * sfr[a], color='mediumblue', lw=3.5)
        #axlog.fill_between(x_agebins[a], np.ones(len(x_agebins[a])) * sfr_l[a], 
        #                np.ones(len(x_agebins[a])) * sfr_u[a], color='gray', alpha=0.5)

    axlog.set_xlim(8.0, x_agebins[-1, -1] + 0.1)
    axlog.spines['left'].set_visible(False)
    axlog.yaxis.set_ticks_position('right')
    axlog.tick_params(labelright=False)

    axlog.xaxis.set_ticks(ticks=[8.0, 9.0])
    axlog.xaxis.set_ticks(ticks=[8.2, 8.4, 8.6, 8.8, 9.2, 9.4, 9.6], minor=True)

    axlog.set_xticklabels(['8', '9'])
    axlog.set_xticklabels(labels=[], minor=True)

    fig.savefig(adap_dir + 'sfh_' + field + '_' + str(galaxy_seq) + '.pdf',
        dpi=200, bbox_inches='tight')
    """

    # ---------- corner plot
    # set up corner ranges and labels
    math_parnames = [
        r'$\mathrm{log(Z_\odot)}$', r'$\mathrm{dust2}$', r'$zf_1$', r'$zf_2$',
        r'$zf_3$', r'$zf_4$', r'$zf_5$', r'$\mathrm{M_s}$',
        r'$\mathrm{f_{agn}}$', r'$\mathrm{agn_\tau}$',
        r'$\mathrm{dust_{ratio}}$', r'$\mathrm{dust_{index}}$'
    ]

    #math_parnames = [r'$\mathrm{M_s}$', r'$\mathrm{log(Z_\odot)}$',
    #r'$\mathrm{dust2}$', r'$\mathrm{t_{age}}$', r'$\mathrm{log(\tau)}$']

    # Fix labels for corner plot and
    # Figure out ranges for corner plot
    corner_range = []
    for d in range(ndim):

        # Get corner estimate and errors
        cq = corner.quantile(x=samples[:, d], q=[0.16, 0.5, 0.84])

        low_err = cq[1] - cq[0]
        up_err = cq[2] - cq[1]

        # Decide the padding for the plot range
        # depending on how large the error is relative
        # to the central estimate.
        if low_err * 2.5 >= cq[1]:
            sigma_padding_low = 1.2
        else:
            sigma_padding_low = 3.0

        if up_err * 2.5 >= cq[1]:
            sigma_padding_up = 1.2
        else:
            sigma_padding_up = 3.0

        low_lim = cq[1] - sigma_padding_low * low_err
        up_lim = cq[1] + sigma_padding_up * up_err

        corner_range.append((low_lim, up_lim))

        # Print estimate to screen
        if 'mass' in parnames[d]:
            pn = '{:.3e}'.format(cq[1])
            pnu = '{:.3e}'.format(up_err)
            pnl = '{:.3e}'.format(low_err)
        else:
            pn = '{:.3f}'.format(cq[1])
            pnu = '{:.3f}'.format(up_err)
            pnl = '{:.3f}'.format(low_err)

        print(parnames[d], ":  ", pn, "+", pnu, "-", pnl)

    # Corner plot
    cornerfig = corner.corner(samples,
                              quantiles=[0.16, 0.5, 0.84],
                              labels=math_parnames,
                              label_kwargs={"fontsize": 14},
                              range=corner_range,
                              smooth=0.5,
                              smooth1d=0.5)

    # loop over all axes *again* and set titles,
    # because corner won't let the title format
    # be set for some titles separately.
    # Looping has to be done twice because corner
    # plotting has to be done first to get the figure.
    corner_axes = np.array(cornerfig.axes).reshape((ndim, ndim))

    for d in range(ndim):
        # Get corner estimate and errors
        cq = corner.quantile(x=samples[:, d], q=[0.16, 0.5, 0.84])

        low_err = cq[1] - cq[0]
        up_err = cq[2] - cq[1]

        ax_c = corner_axes[d, d]

        if 'mass' in parnames[d]:
            ax_c.set_title(math_parnames[d] + r"$ \, =\,$" + csn(cq[1], sigfigs=3) + \
            r"$\substack{+$" + csn(up_err, sigfigs=3) + r"$\\ -$" + \
            csn(low_err, sigfigs=3) + r"$}$", fontsize=11, pad=15)
        else:
            ax_c.set_title(math_parnames[d] + r"$ \, =\,$" + r"${:.3f}$".format(cq[1]) + \
            r"$\substack{+$" + r"${:.3f}$".format(up_err) + r"$\\ -$" + \
            r"${:.3f}$".format(low_err) + r"$}$", fontsize=11, pad=10)

    cornerfig.savefig(adap_dir + 'corner_' + field + '_' + \
        str(galaxy_seq) + '.pdf', dpi=200, bbox_inches='tight')

    sys.exit(0)

    # maximum a posteriori (of the locations visited by the MCMC sampler)
    pmax = np.argmax(result['lnprobability'])
    if results_type == "emcee":
        p, q = np.unravel_index(pmax, result['lnprobability'].shape)
        theta_max = result['chain'][p, q, :].copy()
    else:
        theta_max = result["chain"][pmax, :]

    #print('Optimization value: {}'.format(theta_best))
    #print('MAP value: {}'.format(theta_max))

    # make SED plot for MAP model and some random model
    # randomly chosen parameters from chain
    randint = np.random.randint
    if results_type == "emcee":
        theta = result['chain'][randint(nwalkers), randint(niter)]
    else:
        theta = result["chain"][randint(len(result["chain"]))]

    # generate models
    a = 1.0 + model.params.get('zred', 0.0)  # cosmological redshifting
    # photometric effective wavelengths
    wphot = obs["phot_wave"]
    # spectroscopic wavelengths
    if obs["wavelength"] is None:
        # *restframe* spectral wavelengths, since obs["wavelength"] is None
        wspec = sps.wavelengths
        wspec *= a  #redshift them
    else:
        wspec = obs["wavelength"]

    # sps = reader.get_sps(result)  # this works if using parameter files
    mspec, mphot, mextra = model.mean_model(theta, obs, sps=sps)
    mspec_map, mphot_map, _ = model.mean_model(theta_max, obs, sps=sps)

    # establish bounds
    xmin, xmax = np.min(wphot) * 0.8, np.max(wphot) / 0.8
    ymin, ymax = obs["maggies"].min() * 0.8, obs["maggies"].max() / 0.4

    # Make plot of data and model
    fig3 = plt.figure(figsize=(9, 4))
    ax3 = fig3.add_subplot(111)

    ax3.set_xlabel(r'$\mathrm{\lambda\ [\AA]}$', fontsize=15)
    #ax3.set_ylabel(r'$\mathrm{f_\lambda\ [erg\, s^{-1}\, cm^{-2}\, \AA^{-1}]}$', fontsize=15)
    ax3.set_ylabel(r'$\mathrm{Flux\ Density\ [maggies]}$', fontsize=15)

    ax3.loglog(wspec,
               mspec,
               label='Model spectrum (random draw)',
               lw=0.7,
               color='navy',
               alpha=0.7)
    ax3.loglog(wspec,
               mspec_map,
               label='Model spectrum (MAP)',
               lw=0.7,
               color='green',
               alpha=0.7)
    ax3.errorbar(wphot,
                 mphot,
                 label='Model photometry (random draw)',
                 marker='s',
                 markersize=10,
                 alpha=0.8,
                 ls='',
                 lw=3,
                 markerfacecolor='none',
                 markeredgecolor='blue',
                 markeredgewidth=3)
    ax3.errorbar(wphot,
                 mphot_map,
                 label='Model photometry (MAP)',
                 marker='s',
                 markersize=10,
                 alpha=0.8,
                 ls='',
                 lw=3,
                 markerfacecolor='none',
                 markeredgecolor='green',
                 markeredgewidth=3)
    ax3.errorbar(wphot,
                 obs['maggies'],
                 yerr=obs['maggies_unc'],
                 label='Observed photometry',
                 ecolor='red',
                 marker='o',
                 markersize=10,
                 ls='',
                 lw=3,
                 alpha=0.8,
                 markerfacecolor='none',
                 markeredgecolor='red',
                 markeredgewidth=3)

    # plot transmission curves
    for f in obs['filters']:
        w, t = f.wavelength.copy(), f.transmission.copy()
        t = t / t.max()
        t = 10**(0.2 * (np.log10(ymax / ymin))) * t * ymin
        ax3.loglog(w, t, lw=3, color='gray', alpha=0.7)

    ax3.set_xlim([xmin, xmax])
    ax3.set_ylim([ymin, ymax])
    ax3.legend(loc='best', fontsize=11)

    fig3.savefig(adap_dir + 'sedplot_' + field + '_' + str(galaxy_seq) +
                 '.pdf',
                 dpi=200,
                 bbox_inches='tight')

    plt.clf()
    plt.cla()
    plt.close()

    return None
Example #30
    # Perform and time burn-in phase
    time0 = time.time()
    pos, prob, state = sampler.run_mcmc(pos, 100)
    sampler.reset()
    time1 = time.time()
    print('Burn-in time was ' + str(time1 - time0) + ' seconds.')

    # Perform MCMC fit
    time0 = time.time()
    pos, prob, state = sampler.run_mcmc(pos, 500)
    time1 = time.time()
    print('Fitting time was ' + str(time1 - time0) + ' seconds.')

    samples = sampler.flatchain
    outfile = open("Test_%d.out" % n, 'w')
    outfile.write("Input Parameters #%d: \n" % n)
    outfile.write("Teff = %dK\n" % t)
    outfile.write("log g = %.3fK\n" % g)
    outfile.write("Bfield = %.3fK\n" % b)
    outfile.write("vsin i = %.3fK\n" % v)
    for i in range(4):
        devs = corner.quantile(samples[:,i], [0.16, 0.5, 0.84])
        outfile.write("Pameter %d : %.3f %.3f %.3f\n" % (i, devs[0], devs[1], devs[2]))
    outfile.close()
    figure = corner.corner(samples)
    figure.savefig("test_%d.png" % n)
    del sampler
    del desiredParams
    del samples
Example #31
def collate_data(alldata, **extras):

    ### preliminary stuff
    parnames = alldata[0]['pquantiles']['parnames']
    eparnames = alldata[0]['pextras']['parnames']
    xr_idx = eparnames == 'xray_lum'
    xray = prospector_io.load_xray_cat(xmatch = True, **extras)
    nsamp = 100000 # for newly defined variables

    #### for each object
    fagn, fagn_up, fagn_down, agn_tau, agn_tau_up, agn_tau_down, \
        mass, mass_up, mass_down, lir_luv, lir_luv_up, lir_luv_down, \
        xray_lum, xray_lum_err, database, observatory = [[] for i in range(16)]
    fagn_obs, fagn_obs_up, fagn_obs_down, lmir_lbol, lmir_lbol_up, \
        lmir_lbol_down, xray_hardness, xray_hardness_err = [[] for i in range(8)]
    lagn, lagn_up, lagn_down, lsfr, lsfr_up, lsfr_down, \
        lir, lir_up, lir_down = [[] for i in range(9)]
    sfr, sfr_up, sfr_down, ssfr, ssfr_up, ssfr_down, \
        d2, d2_up, d2_down = [[] for i in range(9)]
    fmir, fmir_up, fmir_down, objname = [], [], [], []
    fmir_chain, agn_tau_chain = [], []
    for ii, dat in enumerate(alldata):
        objname.append(dat['objname'])

        #### mass, SFR, sSFR, dust2
        mass.append(dat['pextras']['q50'][eparnames=='stellar_mass'][0])
        mass_up.append(dat['pextras']['q84'][eparnames=='stellar_mass'][0])
        mass_down.append(dat['pextras']['q16'][eparnames=='stellar_mass'][0])
        sfr.append(dat['pextras']['q50'][eparnames=='sfr_100'][0])
        sfr_up.append(dat['pextras']['q84'][eparnames=='sfr_100'][0])
        sfr_down.append(dat['pextras']['q16'][eparnames=='sfr_100'][0])
        ssfr.append(dat['pextras']['q50'][eparnames=='ssfr_100'][0])
        ssfr_up.append(dat['pextras']['q84'][eparnames=='ssfr_100'][0])
        ssfr_down.append(dat['pextras']['q16'][eparnames=='ssfr_100'][0])
        d2.append(dat['pquantiles']['q50'][parnames=='dust2'][0])
        d2_up.append(dat['pquantiles']['q84'][parnames=='dust2'][0])
        d2_down.append(dat['pquantiles']['q16'][parnames=='dust2'][0])

        #### model f_agn, l_agn, fmir
        fagn.append(dat['pquantiles']['q50'][parnames=='fagn'][0])
        fagn_up.append(dat['pquantiles']['q84'][parnames=='fagn'][0])
        fagn_down.append(dat['pquantiles']['q16'][parnames=='fagn'][0])
        agn_tau.append(dat['pquantiles']['q50'][parnames=='agn_tau'][0])
        agn_tau_up.append(dat['pquantiles']['q84'][parnames=='agn_tau'][0])
        agn_tau_down.append(dat['pquantiles']['q16'][parnames=='agn_tau'][0])
        agn_tau_chain.append(dat['pquantiles']['sample_chain'][:,parnames=='agn_tau'].squeeze())
        lagn.append(dat['pextras']['q50'][eparnames=='l_agn'][0])
        lagn_up.append(dat['pextras']['q84'][eparnames=='l_agn'][0])
        lagn_down.append(dat['pextras']['q16'][eparnames=='l_agn'][0])
        fmir.append(dat['pextras']['q50'][eparnames=='fmir'][0])
        fmir_up.append(dat['pextras']['q84'][eparnames=='fmir'][0])
        fmir_down.append(dat['pextras']['q16'][eparnames=='fmir'][0])
        fmir_chain.append(dat['pextras']['flatchain'][:,eparnames=='fmir'].squeeze())

        #### L_UV / L_IR, LMIR/LBOL
        cent, eup, edo = quantile(np.random.choice(dat['lir'],nsamp) / np.random.choice(dat['luv'],nsamp), [0.5, 0.84, 0.16])

        lir_luv.append(cent)
        lir_luv_up.append(eup)
        lir_luv_down.append(edo)

        cent, eup, edo = quantile(np.random.choice(dat['lir'],nsamp), [0.5, 0.84, 0.16])

        lir.append(cent)
        lir_up.append(eup)
        lir_down.append(edo)

        cent, eup, edo = quantile(np.random.choice(dat['lmir'],nsamp) / np.random.choice(dat['pextras']['flatchain'][:,dat['pextras']['parnames'] == 'lbol'].squeeze(),nsamp), [0.5, 0.84, 0.16])

        lmir_lbol.append(cent)
        lmir_lbol_up.append(eup)
        lmir_lbol_down.append(edo)

        #### x-ray fluxes
        # match
        idx = xray['objname'] == dat['objname']
        if idx.sum() != 1:
            raise ValueError('no unique X-ray match for ' + dat['objname'])
        xflux = xray['flux'][idx][0]
        xflux_err = xray['flux_err'][idx][0]

        # flux is in ergs / cm^2 / s, convert to erg /s 
        pc2cm =  3.08568E18
        dfactor = 4*np.pi*(dat['residuals']['phot']['lumdist']*1e6*pc2cm)**2
        xray_lum.append(xflux * dfactor)
        xray_lum_err.append(xflux_err * dfactor)
        xray_hardness.append(xray['hardness'][idx][0])
        xray_hardness_err.append(xray['hardness_err'][idx][0])

        #### CALCULATE F_AGN_OBS
        # take advantage of the already-computed conversion between FAGN (model) and LAGN (model)
        fagn_chain = dat['pquantiles']['sample_chain'][:,parnames=='fagn']
        lagn_chain = dat['pextras']['flatchain'][:,eparnames == 'l_agn']
        conversion = (lagn_chain / fagn_chain).squeeze()

        ### calculate L_AGN chain
        scale = xray_lum_err[-1]
        if scale <= 0:
            lagn_chain = np.repeat(xray_lum[-1], conversion.shape[0])
        else: 
            lagn_chain = np.random.normal(loc=xray_lum[-1], scale=scale, size=conversion.shape[0])
        obs_fagn_chain = lagn_chain / conversion
        cent, eup, edo = quantile(obs_fagn_chain, [0.5, 0.84, 0.16])

        fagn_obs.append(cent)
        fagn_obs_up.append(eup)
        fagn_obs_down.append(edo)

        ##### L_OBS / L_SFR(MODEL)
        # sample from the chain, assume gaussian errors for x-ray fluxes
        chain = dat['pextras']['flatchain'][:,xr_idx].squeeze()

        if scale <= 0:
            subchain =  np.repeat(xray_lum[-1], nsamp) / \
                        np.random.choice(chain,nsamp)
        else:
            subchain =  np.random.normal(loc=xray_lum[-1], scale=scale, size=nsamp) / \
                        np.random.choice(chain,nsamp)

        cent, eup, edo = quantile(subchain, [0.5, 0.84, 0.16])

        lsfr.append(cent)
        lsfr_up.append(eup)
        lsfr_down.append(edo)

        #### database and observatory
        database.append(str(xray['database'][idx][0]))
        try:
            observatory.append(str(xray['observatory'][idx][0]))
        except KeyError:
            observatory.append(' ')

    out = {}
    out['objname'] = objname
    out['database'] = database
    out['observatory'] = observatory
    out['mass'] = mass
    out['mass_up'] = mass_up
    out['mass_down'] = mass_down
    out['sfr'] = sfr
    out['sfr_up'] = sfr_up
    out['sfr_down'] = sfr_down
    out['ssfr'] = ssfr
    out['ssfr_up'] = ssfr_up
    out['ssfr_down'] = ssfr_down
    out['d2'] = d2
    out['d2_up'] = d2_up
    out['d2_down'] = d2_down
    out['lir_luv'] = lir_luv
    out['lir_luv_up'] = lir_luv_up
    out['lir_luv_down'] = lir_luv_down
    out['lir'] = lir
    out['lir_up'] = lir_up
    out['lir_down'] = lir_down
    out['lmir_lbol'] = lmir_lbol
    out['lmir_lbol_up'] = lmir_lbol_up
    out['lmir_lbol_down'] = lmir_lbol_down
    out['fagn'] = fagn
    out['fagn_up'] = fagn_up
    out['fagn_down'] = fagn_down
    out['agn_tau'] = agn_tau
    out['agn_tau_up'] = agn_tau_up
    out['agn_tau_down'] = agn_tau_down
    out['agn_tau_chain'] = agn_tau_chain
    out['fmir'] = fmir
    out['fmir_up'] = fmir_up
    out['fmir_down'] = fmir_down
    out['fmir_chain'] = fmir_chain
    out['fagn_obs'] = fagn_obs
    out['fagn_obs_up'] = fagn_obs_up
    out['fagn_obs_down'] = fagn_obs_down
    out['lagn'] = lagn 
    out['lagn_up'] = lagn_up
    out['lagn_down'] = lagn_down
    out['lsfr'] = lsfr
    out['lsfr_up'] = lsfr_up
    out['lsfr_down'] = lsfr_down
    out['xray_luminosity'] = xray_lum
    out['xray_luminosity_err'] = xray_lum_err
    out['xray_hardness'] = xray_hardness
    out['xray_hardness_err'] = xray_hardness_err

    for key in out.keys(): out[key] = np.array(out[key])

    #### ADD WISE PHOTOMETRY
    from wise_colors import collate_data as wise_phot
    from wise_colors import vega_conversions
    wise = wise_phot(alldata)

    #### generate x, y values
    w1w2 = -2.5*np.log10(wise['obs_phot']['wise_w1'])+2.5*np.log10(wise['obs_phot']['wise_w2'])
    w1w2 += vega_conversions('wise_w1') - vega_conversions('wise_w2')
    out['w1w2'] = w1w2

    return out
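
# A minimal helper sketch (hypothetical, not part of the original pipeline):
# the repeated q50/q84/q16 pattern above can be wrapped so one call returns
# the median and asymmetric errors. Assumes only numpy and corner.
import numpy as np
import corner

def median_with_errors(samples, weights=None):
    """Return (median, +err, -err) from the 16/50/84th percentiles."""
    q16, q50, q84 = corner.quantile(np.asarray(samples),
                                    q=[0.16, 0.5, 0.84], weights=weights)
    return q50, q84 - q50, q50 - q16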
Example #32
0
result, obs, _ = reader.results_from(adap_dir + results_type + "_" + \
                 field + "_" + str(galaxy_seq) + ".h5", dangerous=False)

# ----------- non param
nagebins = 6
agebins = np.array([[ 0.        ,  8.        ],
                    [ 8.        ,  8.47712125],
                    [ 8.47712125,  9.        ],
                    [ 9.        ,  9.47712125],
                    [ 9.47712125,  9.77815125],
                    [ 9.77815125, 10.13353891]])

samples = result['chain']

# Get the zfractions from corner quantiles
zf1 = corner.quantile(samples[:, 2], q=[0.16, 0.5, 0.84])
zf2 = corner.quantile(samples[:, 3], q=[0.16, 0.5, 0.84])
zf3 = corner.quantile(samples[:, 4], q=[0.16, 0.5, 0.84])
zf4 = corner.quantile(samples[:, 5], q=[0.16, 0.5, 0.84])
zf5 = corner.quantile(samples[:, 6], q=[0.16, 0.5, 0.84])

zf_arr = np.array([zf1[1], zf2[1], zf3[1], zf4[1], zf5[1]])

cq_mass = corner.quantile(samples[:, 7], q=[0.16, 0.5, 0.84])

new_agebins = pt.zred_to_agebins(zred=obj_z, agebins=agebins)

# now convert to sfh and its errors
sfr = pt.zfrac_to_sfr(total_mass=cq_mass[1], z_fraction=zf_arr, agebins=new_agebins)
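
# Hedged sketch: the line above evaluates the SFH only at the median z-fractions
# and median mass, which ignores parameter correlations. An alternative (an
# assumption, not the original script) is to transform every posterior draw and
# summarize each age bin with corner.quantile:
sfr_draws = np.array([
    pt.zfrac_to_sfr(total_mass=s[7], z_fraction=s[2:7], agebins=new_agebins)
    for s in samples
])
sfr_q16, sfr_q50, sfr_q84 = np.array([
    corner.quantile(sfr_draws[:, j], q=[0.16, 0.5, 0.84])
    for j in range(sfr_draws.shape[1])
]).T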

# ----------- plot
Example #33
0
plt.hist(
    trace["ecc"],
    50,
    density=True,
    histtype="step",
    weights=1.0 / trace["ecc"],
    label="$p(e) = 1$",
)
plt.xlabel("$e$")
plt.ylabel(r"$p(e\,|\,\mathrm{data})$")
plt.yticks([])
plt.xlim(0, 0.015)
_ = plt.legend(fontsize=12)
plt.show()

weights = 1.0 / trace["ecc"]
print("for p(e) = e/2: p(e < x) = 0.9 -> x = {0:.5f}".format(
    corner.quantile(trace["ecc"], [0.9])[0]))
print("for p(e) = 1:   p(e < x) = 0.9 -> x = {0:.5f}".format(
    corner.quantile(trace["ecc"], [0.9], weights=weights)[0]))

samples = trace["R1"]

print("for p(e) = e/2: R1 = {0:.3f} ± {1:.3f}".format(np.mean(samples),
                                                      np.std(samples)))

mean = np.sum(weights * samples) / np.sum(weights)
sigma = np.sqrt(np.sum(weights * (samples - mean)**2) / np.sum(weights))
print("for p(e) = 1:   R1 = {0:.3f} ± {1:.3f}".format(mean, sigma))

Example #34
0
def test_invalid_quantiles_1(seed=42):
    np.random.seed(seed)
    corner.quantile(np.random.rand(100), [-0.1, 5])
Example #35
0
        #print marginalized_array
        print('#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')

        np.save(
            './marginalized_totalSFR/samples_values_all_' +
            str(cluster_list[l]) + '_' + str(source_name[x]) + '.npy',
            marginalized_array)
        #fig = corner.corner(marginalized_array, labels=["$z$", "$log(a)$", r"$log(L(L_{\odot}))$", r"$log(SFR(M_{\odot}/yr))$"],
        #                    quantiles=[0.16, 0.5, 0.84],
        #                    show_titles=True, title_kwargs={"fontsize": 12})

        #fig.savefig("./marginalized/triangle_marginalized_"+cluster_list[l]+"_"+str(source_name[x])+".png")
        #plt.close(fig)

        lower_limit_z = corner.quantile(marginalized_array[:, 0],
                                        q=[0.16],
                                        weights=None)
        fitted_median_z = corner.quantile(marginalized_array[:, 0],
                                          q=[0.5],
                                          weights=None)
        higher_limit_z = corner.quantile(marginalized_array[:, 0],
                                         q=[0.84],
                                         weights=None)
        z_distribution_all.append(fitted_median_z[0])
        z_distribution_all_uperr.append(higher_limit_z[0] - fitted_median_z[0])
        z_distribution_all_lowerr.append(fitted_median_z[0] - lower_limit_z[0])

        #print '2222222222222222222222222222'
        #print source_name[x]
        #print fitted_median_z
        #print lower_limit_z
Example #36
0
def test_invalid_quantiles_3(seed=42):
    np.random.seed(seed)
    corner.quantile(np.random.rand(100), [0.5, 1.0, 8.1])
Example #37
0
def plot_star(star, predict, y_names, out_dir=plot_dir):
    ## Regular histograms
    middles = np.mean(predict, 0)
    stds = np.std(predict, 0)
    
    #outstr = star
    num_ys = predict.shape[1]
    
    rows, cols, sqrt = get_rc(num_ys)
    
    plt.figure(figsize=(6.97522*2, 4.17309*2), dpi=400, 
        facecolor='w', edgecolor='k')
    for (pred_j, name) in enumerate(y_names[0:num_ys]):
        (m, s) = (middles[pred_j], stds[pred_j])
        
        if num_ys%2==0 or num_ys%3==0 or int(sqrt)==sqrt:
            ax = plt.subplot(rows, cols, pred_j+1)
        elif pred_j%2==0 and pred_j == num_ys-1:
            ax = plt.subplot2grid((rows, cols), (pred_j//2, 1), colspan=2)
        else:
            ax = plt.subplot2grid((rows, cols), (pred_j//2, (pred_j%2)*2),
                colspan=2)
        
        n, bins, patches = ax.hist(predict[:,pred_j], 50, density=True,
            histtype='stepfilled', color='white')
        
        #if y_central is not None:
        #    mean = np.mean(y_central[:,pred_j])
        #    std = np.std(y_central[:,pred_j])
        #    
        #    plot_line(mean, n, bins, plt, 'r--')
        #    plot_line(mean+std, n, bins, plt, 'r-.')
        #    plot_line(mean-std, n, bins, plt, 'r-.')
        
        q_16, q_50, q_84 = corner.quantile(predict[:,pred_j], [0.16, 0.5, 0.84])
        q_m, q_p = q_50-q_16, q_84-q_50
        plot_line(q_50, n, bins, plt, 'k--')
        plot_line(q_16, n, bins, plt, 'k-.')
        plot_line(q_84, n, bins, plt, 'k-.')
        
        # Format the quantile display.
        fmt = "{{0:{0}}}".format(".3g").format
        title = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
        title = title.format(fmt(q_50), fmt(q_m), fmt(q_p))
        
        ax.annotate(r"$\epsilon = %.3g\%%$" % (s/m*100),
            xy=(0.99, 0.12), xycoords='axes fraction',
            horizontalalignment='right', verticalalignment='center')
        
        plt.xlabel(y_latex[y_names[pred_j]] + " = " + title)
        plt.locator_params(axis='x', nbins=3)
        
        xs = [max(0, m-4*s), m+4*s]
        
        xticks = [max(0, m-3*s), m, m+3*s]
        ax.set_xticks(xticks)
        ax.set_xlim(xs)
        ax.set_xticklabels(['%.3g'%xtick for xtick in xticks])
        ax.set_yticklabels('',visible=False)
        
        ax.set_frame_on(False)
        ax.get_xaxis().tick_bottom()
        ax.axes.get_yaxis().set_visible(False)
        
        #xmin, xmax = ax1.get_xaxis().get_view_interval()
        ymin, ymax = ax.get_yaxis().get_view_interval()
        ax.add_artist(mpl.lines.Line2D(xs, (ymin, ymin), 
            color='black', linewidth=2))
        
        #ax.minorticks_on()
        plt.tight_layout()
    
    plt.savefig(os.path.join(out_dir, star + '.pdf'), dpi=400)
    plt.close()
Example #38
0
def test_invalid_quantiles_2(seed=42):
    np.random.seed(seed)
    corner.quantile(np.random.rand(100), 5)
Example #39
0
def read_pickle_make_plots_sn(object_type, ndim, args_obj, label_list,
                              truth_dict, savedir):

    h5_path = savedir + 'emcee_sampler_' + object_type + '.h5'
    sampler = emcee.backends.HDFBackend(h5_path)

    samples = sampler.get_chain()
    print("\nRead in sampler:", h5_path)
    print("Samples shape:", samples.shape)

    #reader = emcee.backends.HDFBackend(pkl_path.replace('.pkl', '.h5'))
    #samples = reader.get_chain()
    #tau = reader.get_autocorr_time(tol=0)

    # Get autocorrelation time
    # Discard burn-in. You do not want to consider the burn in the corner plots/estimation.
    tau = sampler.get_autocorr_time(tol=0)
    if not np.any(np.isnan(tau)):
        burn_in = int(2 * np.max(tau))
        thinning_steps = int(0.5 * np.min(tau))
    else:
        burn_in = 200
        thinning_steps = 30

    print("Average Tau:", np.mean(tau))
    print("Burn-in:", burn_in)
    print("Thinning steps:", thinning_steps)

    # construct truth arr and plot
    truth_arr = np.array(
        [truth_dict['z'], truth_dict['phase'], truth_dict['Av']])

    # plot trace
    fig1, axes1 = plt.subplots(ndim, figsize=(10, 6), sharex=True)

    for i in range(ndim):
        ax1 = axes1[i]
        ax1.plot(samples[:, :, i], "k", alpha=0.05)
        ax1.axhline(y=truth_arr[i], color='tab:red', lw=2.0)
        ax1.set_xlim(0, len(samples))
        ax1.set_ylabel(label_list[i], fontsize=15)
        ax1.yaxis.set_label_coords(-0.1, 0.5)

    axes1[-1].set_xlabel("Step number")

    fig1.savefig(savedir + 'emcee_trace_' + object_type + '.pdf',
                 dpi=200,
                 bbox_inches='tight')

    # Create flat samples
    flat_samples = sampler.get_chain(discard=burn_in,
                                     thin=thinning_steps,
                                     flat=True)
    print("\nFlat samples shape:", flat_samples.shape)

    # plot corner plot
    cq_z = corner.quantile(x=flat_samples[:, 0], q=[0.16, 0.5, 0.84])
    cq_day = corner.quantile(x=flat_samples[:, 1], q=[0.16, 0.5, 0.84])
    cq_av = corner.quantile(x=flat_samples[:, 2], q=[0.16, 0.5, 0.84])

    # print parameter estimates
    print(f"{bcolors.CYAN}")
    print("Parameter estimates:")
    print("Redshift: ", cq_z)
    print("Supernova phase [day]:", cq_day)
    print("Visual extinction [mag]:", cq_av)
    print(f"{bcolors.ENDC}")

    fig = corner.corner(flat_samples,
                        quantiles=[0.16, 0.5, 0.84],
                        labels=label_list,
                        label_kwargs={"fontsize": 14},
                        show_titles=True,
                        title_kwargs={"fontsize": 14},
                        truth_color='tab:red',
                        truths=truth_arr,
                        smooth=0.5,
                        smooth1d=0.5)

    # Extract the axes
    axes = np.array(fig.axes).reshape((ndim, ndim))

    # Get the redshift axis
    # and edit how the errors are displayed
    ax_z = axes[0, 0]

    z_err_high = cq_z[2] - cq_z[1]
    z_err_low = cq_z[1] - cq_z[0]

    ax_z.set_title(r"$z \, =\,$" + r"${:.3f}$".format(cq_z[1]) + \
        r"$\substack{+$" + r"${:.3f}$".format(z_err_high) + r"$\\ -$" + \
        r"${:.3f}$".format(z_err_low) + r"$}$",
        fontsize=11)

    fig.savefig(savedir + 'corner_' + object_type + '.pdf',
                dpi=200,
                bbox_inches='tight')

    # ------------ Plot 200 random models from the parameter
    # space within +-1 sigma of corner estimates
    # first pull out required stuff from args
    wav = args_obj[0]
    flam = args_obj[1]
    ferr = args_obj[2]

    fig3 = plt.figure(figsize=(9, 4))
    ax3 = fig3.add_subplot(111)

    ax3.set_xlabel(r'$\mathrm{\lambda\ [\AA]}$', fontsize=15)
    ax3.set_ylabel(
        r'$\mathrm{f_\lambda\ [erg\, s^{-1}\, cm^{-2}\, \AA^{-1}]}$',
        fontsize=15)

    model_count = 0
    ind_list = []

    while model_count < 200:

        ind = int(np.random.randint(len(flat_samples), size=1))
        ind_list.append(ind)

        # make sure sample has correct shape
        sample = flat_samples[ind]

        model_okay = 0

        sample = sample.reshape(3)

        # Get the parameters of the sample
        model_z = sample[0]
        model_day = sample[1]
        model_av = sample[2]

        # Check that the model is within +-1 sigma
        # of value inferred by corner contours
        if (model_z >= cq_z[0]) and (model_z <= cq_z[2]) and \
           (model_day >= cq_day[0]) and (model_day <= cq_day[2]) and \
           (model_av >= cq_av[0]) and (model_av <= cq_av[2]):

            model_okay = 1

        # Now plot if the model is okay
        if model_okay:

            m = model_sn(wav, sample[0], sample[1], sample[2])

            a = np.nansum(flam * m / ferr**2) / np.nansum(m**2 / ferr**2)
            m = m * a

            ax3.plot(wav, m, color='royalblue', lw=0.5, alpha=0.05, zorder=2)

            model_count += 1

    ax3.plot(wav, flam, color='k', lw=1.0, zorder=1)
    ax3.fill_between(wav,
                     flam - ferr,
                     flam + ferr,
                     color='gray',
                     alpha=0.5,
                     zorder=1)

    # ADD LEGEND
    ax3.text(x=0.65,
             y=0.92,
             s='--- Simulated data',
             verticalalignment='top',
             horizontalalignment='left',
             transform=ax3.transAxes,
             color='k',
             size=12)
    ax3.text(x=0.65,
             y=0.85,
             s='--- 200 randomly chosen samples',
             verticalalignment='top',
             horizontalalignment='left',
             transform=ax3.transAxes,
             color='royalblue',
             size=12)

    fig3.savefig(savedir + 'emcee_overplot_' + object_type + '.pdf',
                 dpi=200,
                 bbox_inches='tight')

    # Close all figures
    fig1.clear()
    fig.clear()
    fig3.clear()

    #plt.clf()
    #plt.cla()
    plt.close(fig1)
    plt.close(fig)
    plt.close(fig3)

    return None
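
# Hedged aside on the +/-1 sigma selection in the loop above: requiring every
# parameter to fall inside its own marginal [q16, q84] interval keeps a box in
# parameter space whose posterior mass is generally not 68%. A sketch of the
# unbiased alternative (hypothetical helper, same model_sn as above) is to
# overplot plain random posterior draws:
def overplot_random_draws(ax, wav, flam, ferr, flat_samples, n_draws=200):
    for ind in np.random.randint(len(flat_samples), size=n_draws):
        z, day, av = flat_samples[ind]
        m = model_sn(wav, z, day, av)
        # same least-squares scaling as in the loop above
        a = np.nansum(flam * m / ferr**2) / np.nansum(m**2 / ferr**2)
        ax.plot(wav, m * a, color='royalblue', lw=0.5, alpha=0.05, zorder=2)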
Example #40
0
def collate_data(alldata,alldata_noagn):

	### number of random draws
	size = 10000

	### normal parameter labels
	parnames = alldata_noagn[0]['pquantiles']['parnames'].tolist()
	parnames2 = alldata[0]['pquantiles']['parnames'].tolist()
	parlabels = [r'log(M$_{\mathrm{form}}$/M$_{\odot}$)', 'SFH 0-100 Myr', 'SFH 100-300 Myr', 'SFH 300 Myr-1 Gyr', 
	         'SFH 1-3 Gyr', 'SFH 3-6 Gyr', r'$\tau_{\mathrm{V,diffuse}}$', r'log(Z/Z$_{\odot}$)', 'diffuse dust index',
	         'birth-cloud dust', r'dust emission Q$_{\mathrm{PAH}}$',r'dust emission $\gamma$',r'dust emission U$_{\mathrm{min}}$']

	### extra parameters
	eparnames_all = alldata[0]['pextras']['parnames']
	eparnames = ['stellar_mass','sfr_100', 'ssfr_100', 'half_time']
	eparlabels = [r'log(M$_*$) [M$_{\odot}$]',r'log(SFR) [M$_{\odot}$ yr$^{-1}$]',r'log(sSFR) [yr$^{-1}$]', r"log(t$_{\mathrm{half-mass}}$) [Gyr]"]

	### let's do something special here
	fparnames = ['halpha','m23_frac']
	fparlabels = [r'log(H$_{\alpha}$ flux)',r'M$_{\mathrm{0.1-1 Gyr}}$/M$_{\mathrm{total}}$']
	objname = []

	### setup dictionary
	outvals, outq, outerrs, outlabels = {},{},{},{}
	alllabels = parlabels + eparlabels + fparlabels
	for ii,par in enumerate(parnames+eparnames+fparnames): 
		outvals[par] = []
		outq[par] = {}
		outq[par]['q50'],outq[par]['q84'],outq[par]['q16'] = [],[],[]
		outlabels[par] = alllabels[ii]

	### fill with data
	for dat,datnoagn in zip(alldata,alldata_noagn):
		objname.append(dat['objname'])
		for ii,par in enumerate(parnames):
			p1 = np.random.choice(dat['pquantiles']['sample_chain'][:,ii].squeeze(),size=size)
			p2 = np.random.choice(datnoagn['pquantiles']['sample_chain'][:,ii].squeeze(),size=size)
			ratio = p1 - p2
			for q in outq[par].keys(): 
				quant = float(q[1:])/100
				outq[par][q].append(quantile(ratio, [quant])[0])
			outvals[par].append(outq[par]['q50'][-1])
		for par in eparnames:
			match = eparnames_all == par
			match2 = datnoagn['pextras']['parnames'] == par
			p1 = np.random.choice(np.log10(dat['pextras']['flatchain'][:,match]).squeeze(),size=size)
			p2 = np.random.choice(np.log10(datnoagn['pextras']['flatchain'][:,match2]).squeeze(),size=size)
			ratio = p1 - p2
			for q in outq[par].keys(): 
				quant = float(q[1:])/100
				outq[par][q].append(quantile(ratio, [quant])[0])
			outvals[par].append(outq[par]['q50'][-1])

		par = 'halpha'
		ha_idx = dat['model_emline']['emnames'] == 'Halpha'
		p1 = np.random.choice(np.log10(dat['model_emline']['flux']['chain'][:,ha_idx]).squeeze(),size=size)
		p2 = np.random.choice(np.log10(datnoagn['model_emline']['flux']['chain'][:,ha_idx]).squeeze(),size=size)
		ratio = p1 - p2
		for q in outq[par].keys(): 
			quant = float(q[1:])/100
			outq[par][q].append(quantile(ratio, [quant])[0])
		outvals[par].append(outq[par]['q50'][-1])

		### this is super ugly but it works
		# calculate tuniv, create agelim array
		par = 'm23_frac'
		zfrac_idx = np.array(['z_fraction' in p for p in parnames],dtype=bool)
		zfrac_idx2 = np.array(['z_fraction' in p for p in parnames2],dtype=bool)

		tuniv = WMAP9.age(dat['residuals']['phot']['z']).value
		agelims = [0.0,8.0,8.5,9.0,9.5,9.8,10.0]
		agelims[-1] = np.log10(tuniv*1e9)
		time_per_bin = []
		for i in range(len(agelims)-1): time_per_bin.append(10**agelims[i+1]-10**agelims[i])

		# now calculate fractions for each of them
		sfrfrac = transform_zfraction_to_sfrfraction(dat['pquantiles']['sample_chain'][:,zfrac_idx2])
		full = np.concatenate((sfrfrac,(1-sfrfrac.sum(axis=1))[:,None]),axis=1)
		mass_fraction = full*np.array(time_per_bin)
		mass_fraction /= mass_fraction.sum(axis=1)[:,None]
		m23_agn = mass_fraction[:,1:3].sum(axis=1)

		sfrfrac = transform_zfraction_to_sfrfraction(datnoagn['pquantiles']['sample_chain'][:,zfrac_idx])
		full = np.concatenate((sfrfrac,(1-sfrfrac.sum(axis=1))[:,None]),axis=1)
		mass_fraction = full*np.array(time_per_bin)
		mass_fraction /= mass_fraction.sum(axis=1)[:,None]
		m23_noagn = mass_fraction[:,1:3].sum(axis=1)

		ratio = np.random.choice(m23_agn,size=size) - np.random.choice(m23_noagn,size=size)
		for q in outq[par].keys(): 
			quant = float(q[1:])/100
			outq[par][q].append(quantile(ratio, [quant])[0])
		outvals[par].append(outq[par]['q50'][-1])

	### do the errors
	for par in outlabels.keys():
		outerrs[par] = asym_errors(np.array(outq[par]['q50']), 
				                   np.array(outq[par]['q84']),
				                   np.array(outq[par]['q16']),log=False)
		outvals[par] = np.array(outvals[par])

	### AGN parameters
	agn_pars = {}
	pnames = ['fagn', 'agn_tau']
	for p in pnames: agn_pars[p] = []
	agn_parnames = alldata[0]['pquantiles']['parnames']
	for dat in alldata:
		for key in agn_pars.keys():
			agn_pars[key].append(dat['pquantiles']['q50'][agn_parnames==key][0])

	### fill output
	out = {}
	out['median'] = outvals
	out['errs'] = outerrs
	out['labels'] = outlabels
	out['ordered_labels'] = np.concatenate((eparnames,np.array(parnames),np.array(fparnames)))
	out['agn_pars'] = agn_pars
	out['objname'] = objname
	return out
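
# A refactor sketch (hypothetical helper, not in the original): the repeated
# quantile-appending blocks inside collate_data above could collapse into one
# function; `quantile` is the same helper the example already calls.
def append_quantiles(outq_par, ratio):
    for q in outq_par:                  # keys like 'q50', 'q84', 'q16'
        level = float(q[1:]) / 100.0    # 'q84' -> 0.84
        outq_par[q].append(quantile(ratio, [level])[0])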
Example #41
0
except:
    print('file not found')

print('sps and model')
sps = build_sps()
mod = build_model()
thetas = mod.theta_labels()
#print(mod)
thetas_50 = []
thetas_16 = []
thetas_84 = []
print('quantiles for all thetas')
for idx in range(len(thetas)):
    chain = [item[idx] for item in res['chain']]
    quan = quantile(chain, [.16, .5, .84])
    thetas_50.append(quan[1])
    thetas_16.append(quan[0])
    thetas_84.append(quan[2])

mod_50 = mod.mean_model(thetas_50, obs, sps)
massfrac_50 = mod_50[-1]
mod_16 = mod.mean_model(thetas_16, obs, sps)
massfrac_16 = mod_16[-1]
mod_84 = mod.mean_model(thetas_84, obs, sps)
massfrac_84 = mod_84[-1]
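
# Hedged sketch: feeding per-parameter percentiles through mean_model does not
# yield percentiles of the mass fraction, since it ignores correlations between
# thetas. An alternative (an assumption, not the original script) is to push
# random posterior draws through the model and summarize those:
draw_idx = np.random.randint(len(res['chain']), size=100)
mfrac_draws = np.array([mod.mean_model(res['chain'][i], obs, sps)[-1]
                        for i in draw_idx])
mfrac_16, mfrac_50, mfrac_84 = quantile(mfrac_draws, [.16, .5, .84])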

print('mass and Z')

mass_50 = thetas_50[thetas.index('massmet_1')]
mass_16 = thetas_16[thetas.index('massmet_1')]