def get_supernovae(n, data=True):
    """Simulate ``n`` supernovae from the population model and summarise each one.

    Redshifts come from ``RedshiftSampler``; population parameters come from
    ``get_truths_labels_significance`` (first element of each entry is the
    name, second is the value). Distance moduli use a ``FlatwCDM`` cosmology
    with H0 = 70 and ``Om`` taken from those truths.

    Args:
        n: number of supernovae to draw.
        data: forwarded to ``get_ia_summary_stats``. When truthy, the ``"lc"``
            entry of every record is ``None``; otherwise it is the ``"lc"``
            value returned by ``get_ia_summary_stats``.

    Returns:
        list[dict]: one record per supernova whose processing did not raise
        ``RuntimeError``, with keys ``MB``, ``mB``, ``x1``, ``c``, ``m`` (the
        uniform draw used as the high-mass probability), ``z``, ``pc``,
        ``lp``, ``dp``, ``parameters``, ``covariance`` and ``lc``.
        Supernovae that raise ``RuntimeError`` are reported via ``print`` and
        skipped.
    """
    # Keep the draw order unchanged: redshifts, then mass probabilities,
    # then one population draw per supernova inside the loop.
    zs = RedshiftSampler().sample(size=n)

    truth = {entry[0]: entry[1] for entry in get_truths_labels_significance()}
    cosmology = FlatwCDM(70.0, truth["Om"])
    distance_moduli = cosmology.distmod(zs).value

    alpha, beta = truth["alpha"], truth["beta"]
    dscale, dratio = truth["dscale"], truth["dratio"]
    p_high_masses = np.random.uniform(low=0.0, high=1.0, size=n)

    means = np.array([truth[name] for name in ("mean_MB", "mean_x1", "mean_c")])
    sigmas = np.array([truth[name] for name in ("sigma_MB", "sigma_x1", "sigma_c")])
    # NOTE(review): presumably a Cholesky-style factor of the correlation
    # matrix, so chol @ chol.T is the correlation matrix — confirm.
    chol = truth["intrinsic_correlation"]
    # Covariance = correlation matrix scaled element-wise by sigma_i * sigma_j.
    pop_cov = np.dot(chol, chol.T) * np.dot(sigmas[:, None], sigmas[None, :])

    supernovae = []
    for z, p, mu in zip(zs, p_high_masses, distance_moduli):
        try:
            MB, x1, c = np.random.multivariate_normal(means, pop_cov)
            # Redshift-dependent mass-step size; applied in proportion to p below.
            mass_correction = dscale * (1.9 * (1 - dratio) / (0.9 + np.power(10, 0.95 * z)) + dratio)
            adjustment = - alpha * x1 + beta * c - mass_correction * p
            MB_adj = MB + adjustment
            stats = get_ia_summary_stats(z, MB_adj, x1, c, cosmo=cosmology, data=data)
            record = {
                "MB": MB,
                "mB": MB_adj + mu,
                "x1": x1,
                "c": c,
                "m": p,
                "z": z,
                "pc": stats["passed_cut"],
                "lp": multivariate_normal.logpdf([MB, x1, c], means, pop_cov),
                "dp": stats.get("delta_p"),
                "parameters": stats.get("params"),
                "covariance": stats.get("cov"),
                "lc": None if data else stats.get("lc"),
            }
        except RuntimeError:
            print("Error on nova: %0.2f %0.2f %0.2f %0.3f" % (MB, x1, c, z))
        else:
            supernovae.append(record)
    return supernovae
def load_stan_from_folder(folder, replace=True, merge=True, cut=False, num=None):
    """Load Stan chain pickles from ``folder``, grouped by run key.

    A file is used when its name starts with ``"stan"`` and ends with
    ``".pkl"``. Its group key is the second ``_``-separated token of the
    filename. Chains sharing a key are concatenated column by column.

    Args:
        folder: directory to scan.
        replace: forwarded to ``get_chain``. When False, each array-valued
            parameter's original-name column is also deleted from the chain.
        merge: when True, concatenate all accepted groups into a single tuple.
        cut: when True, reject a group if 4 or more entries of its
            ChainConsumer summary have a first value of ``None``.
        num: if given, keep only files whose name contains ``"_<num>_"``.

    Returns:
        A list of tuples ``(chain, posterior, truths, params, full_params,
        n_chains, weights, old_weights)``, one per accepted group. When
        ``merge`` is True, a single such tuple instead: chain columns,
        posterior, weights and old weights concatenated across groups, and
        ``n_chains`` summed.

    Raises:
        AssertionError: if no matching files were found.
        IndexError: if ``merge`` is True and every group was rejected by ``cut``.
    """
    vals = get_truths_labels_significance()
    # Each entry's label (index 2) is either one string or a list of strings;
    # flatten them into a single list of parameter labels.
    full_params = list(itertools.chain.from_iterable(
        [k[2]] if not isinstance(k[2], list) else k[2]
        for k in vals if k[2] is not None))
    params = list(itertools.chain.from_iterable(
        [k[2]] if not isinstance(k[2], list) else k[2]
        for k in vals if k[3] and k[2] is not None))
    full_params.remove("$\\rho$")
    name_map = {k[0]: k[2] for k in vals}
    truths = {k[2]: k[1] for k in vals if not isinstance(k[2], list)}
    # Parameters whose truth value is not a plain number (e.g. numpy arrays).
    is_array = [k[0] for k in vals if not isinstance(k[1], (float, int))]

    fs = sorted(f for f in os.listdir(folder) if f.startswith("stan") and f.endswith(".pkl"))
    if num is not None:
        tag = "_%d_" % num
        fs = [f for f in fs if tag in f]

    chains_by_key = {}
    for fname in fs:
        run_key = fname.split("_")[1]
        path = os.path.abspath(os.path.join(folder, fname))
        chains_by_key.setdefault(run_key, []).append(get_chain(path, name_map, replace=replace))
    assert len(chains_by_key) > 0, "No results found"

    result = []
    for run_key in sorted(chains_by_key):
        chains = chains_by_key[run_key]
        chain = chains[0]
        for extra in chains[1:]:
            for name in chain.keys():
                chain[name] = np.concatenate((chain[name], extra[name]))

        posterior = chain["Posterior"]
        del chain["Posterior"]
        if "weight" in chain:
            weights = chain["weight"]
            del chain["weight"]
        else:
            weights = np.ones(posterior.shape)
        # The old-weight column may be stored under an escaped (LaTeX) or a
        # plain name. Take the first one present, in that order of preference.
        for ow_name in ("old\\_weight", "old_weight"):
            if ow_name in chain:
                ow = chain[ow_name]
                del chain[ow_name]
                break
        else:
            ow = np.ones(posterior.shape)
        print(chain.keys())

        for param in is_array:
            latex = name_map[param]
            truth_val = truths[latex]
            shape = truth_val.shape
            if not replace:
                del chain[param]
            # Only 1-D vector parameters are split into per-component columns.
            # 0-d (scalar) and 2-D+ parameters are left untouched; 0-d used to
            # crash on shape[0].
            if len(shape) != 1 or latex not in chain:
                continue
            # `latex` is used as a format template (e.g. containing "%d") for
            # the per-component column names.
            for i in range(shape[0]):
                chain[latex % i] = chain[latex][:, i]
                truths[latex % i] = truth_val[i]
            del chain[latex]

        consumer = ChainConsumer()
        consumer.add_chain(chain, weights=weights)
        summary = consumer.get_summary()
        # Use a distinct loop name so the group key `run_key` is never
        # clobbered (list comprehensions leak their variable on Python 2).
        num_failed = sum(1 for stat_name in summary.keys() if summary[stat_name][0] is None)
        if not cut or num_failed < 4:
            print("Chain %s good" % run_key)
            result.append((chain, posterior, truths, params, full_params, len(chains), weights, ow))
        else:
            print("Chain %s is bad" % run_key)

    if not merge:
        return result
    if not result:
        raise IndexError("No chains survived the cut; nothing to merge")

    merged = list(result[0])
    for other in result[1:]:
        for name in merged[0].keys():
            merged[0][name] = np.concatenate((merged[0][name], other[0][name]))
        # Per-sample arrays: posterior (1), weights (6), old weights (7).
        for idx in (1, 6, 7):
            merged[idx] = np.concatenate((merged[idx], other[idx]))
        merged[5] += other[5]  # total number of chain files
    return tuple(merged)