def calculate_kld(pe, qe, vb=False):
    """
    Calculates the Kullback-Leibler divergence between two PDFs

    Parameters
    ----------
    pe: numpy.ndarray, float
        probability distribution evaluated on a grid whose distance from
        `q` will be calculated
    qe: numpy.ndarray, float
        probability distribution evaluated on a grid whose distance to
        `p` will be calculated
    vb: boolean
        report on progress to stdout?

    Returns
    -------
    Dpq: float
        the value of the Kullback-Leibler divergence from `q` to `p`
    """
    # Normalize the evaluations, so that the integrals can be done
    # (very approximately!) by simple summation:
    pn = pe / np.sum(pe)
    qn = qe / np.sum(qe)
    # Compute the log of the normalized PDFs
    logp = u.safe_log(pn)
    logq = u.safe_log(qn)
    # Calculate the KLD from q to p
    Dpq = np.sum(pn * (logp - logq))
    return Dpq

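# Hedged usage sketch (not part of the original module): compares two
# discretized Gaussians on a shared grid, assuming `np` is numpy and that
# `u.safe_log` behaves like np.log with zeros floored to a tiny value.
def _example_calculate_kld():
    import numpy as np
    grid = np.linspace(0., 1., 100)
    p = np.exp(-0.5 * ((grid - 0.5) / 0.1) ** 2)
    q = np.exp(-0.5 * ((grid - 0.6) / 0.1) ** 2)
    # The KLD is non-negative and shrinks to zero as q approaches p;
    # calculate_kld normalizes internally, so raw evaluations suffice
    return calculate_kld(p, q)
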
def calculate_mexp(self, vb=False):
    """
    Calculates the marginalized expected value estimator of the redshift
    density function

    Parameters
    ----------
    vb: boolean, optional
        True to print progress messages to stdout, False to suppress

    Returns
    -------
    log_exp_nz: ndarray, float
        array of logged redshift density function bin values
    """
    if 'log_mexp_nz' not in self.info['estimators']:
        # Reduce each interim posterior to its expected value,
        # approximating the integral by a sum over bins
        expprep = np.sum(self.bin_mids * self.pdfs * self.bin_difs, axis=1)
        self.exp_nz = np.zeros(self.n_bins)
        # Histogram the per-galaxy expected values into the coarse bins;
        # >= on the lower edge keeps values landing exactly on an edge
        for z in expprep:
            for k in range(self.n_bins):
                if self.bin_ends[k] <= z < self.bin_ends[k + 1]:
                    self.exp_nz[k] += 1.
        self.exp_nz /= self.bin_difs * self.n_pdfs
        self.log_exp_nz = u.safe_log(self.exp_nz)
        self.info['estimators']['log_mexp_nz'] = self.log_exp_nz
    else:
        self.log_exp_nz = self.info['estimators']['log_mexp_nz']
        self.exp_nz = np.exp(self.log_exp_nz)
    return self.log_exp_nz

def evaluate_log_hyper_likelihood(self, log_nz):
    """
    Function to evaluate the log hyperlikelihood

    Parameters
    ----------
    log_nz: numpy.ndarray, float
        vector of logged redshift density bin values at which to evaluate
        the hyperlikelihood

    Returns
    -------
    log_hyper_likelihood: float
        log likelihood probability associated with parameters in log_nz
    """
    nz = np.exp(log_nz)
    # testing whether the normalization step is still necessary
    norm_nz = nz / np.dot(nz, self.bin_difs)
    # Each galaxy's contribution marginalizes the proposed n(z) against
    # its interim posterior, reweighted by the interim prior
    hyper_lfs = np.sum(norm_nz[None, :] * self.pdfs / self.int_pr[None, :]
                       * self.bin_difs, axis=1)
    log_hyper_likelihood = (np.sum(u.safe_log(hyper_lfs)) -
                            np.log(np.dot(norm_nz, self.bin_difs)))
    return log_hyper_likelihood

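# Standalone sketch of the same marginalization (illustrative only, with
# hypothetical array arguments): for interim posteriors p[j, k], interim
# prior pi[k], bin widths dz[k], and proposed density nz[k], each galaxy
# contributes sum_k nz[k] * p[j, k] / pi[k] * dz[k], and the log
# hyper-likelihood is the sum of the logs of those contributions.
def _example_log_hyper_likelihood(nz, p, pi, dz):
    import numpy as np
    nz = nz / np.dot(nz, dz)  # normalize the proposal to integrate to one
    per_gal = np.sum(nz[None, :] * p / pi[None, :] * dz[None, :], axis=1)
    return np.sum(np.log(per_gal))
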
def calculate_mmap(self, vb=False):
    """
    Calculates the marginalized maximum a posteriori estimator of the
    redshift density function

    Parameters
    ----------
    vb: boolean, optional
        True to print progress messages to stdout, False to suppress

    Returns
    -------
    log_map_nz: ndarray, float
        array of logged redshift density function bin values
    """
    if 'log_mmap_nz' not in self.info['estimators']:
        self.map_nz = np.zeros(self.n_bins)
        # Histogram the mode of each interim posterior
        mappreps = [np.argmax(l) for l in self.log_pdfs]
        for m in mappreps:
            self.map_nz[m] += 1.
        # Normalize element-wise by bin width and the number of PDFs
        self.map_nz /= self.bin_difs * self.n_pdfs
        self.log_map_nz = u.safe_log(self.map_nz)
        self.info['estimators']['log_mmap_nz'] = self.log_map_nz
    else:
        self.log_map_nz = self.info['estimators']['log_mmap_nz']
        self.map_nz = np.exp(self.log_map_nz)
    return self.log_map_nz

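# Minimal standalone version of the pattern shared by calculate_mexp and
# calculate_mmap (a hypothetical helper, not class API): reduce each PDF
# to a point estimate, then build a density-normalized histogram of those
# estimates; shown here for the mode-based (MAP) case.
def _example_point_estimate_nz(pdfs, bin_ends):
    import numpy as np
    dz = bin_ends[1:] - bin_ends[:-1]
    modes = np.argmax(pdfs, axis=1)  # mode bin index of each PDF
    counts = np.bincount(modes, minlength=len(dz)).astype(float)
    return counts / (dz * len(pdfs))  # integrates to one over the bins
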
def calculate_mmle(self, start, vb=False, no_data=0, no_prior=0):
    """
    Calculates the marginalized maximum likelihood estimator of the
    redshift density function

    Parameters
    ----------
    start: numpy.ndarray, float
        array of log redshift density function bin values at which to
        begin optimization
    vb: boolean, optional
        True to print progress messages to stdout, False to suppress
    no_data: boolean, optional
        True to exclude data contribution to hyperposterior
    no_prior: boolean, optional
        True to exclude prior contribution to hyperposterior

    Returns
    -------
    log_mle_nz: numpy.ndarray, float
        array of logged redshift density function bin values maximizing
        hyperposterior
    """
    if 'log_mmle_nz' not in self.info['estimators']:
        log_mle = self.optimize(start, no_data=no_data, no_prior=no_prior,
                                vb=vb)
        mle_nz = np.exp(log_mle)
        # Renormalize so the estimator integrates to one over the bins
        self.mle_nz = mle_nz / np.dot(mle_nz, self.bin_difs)
        self.log_mle_nz = u.safe_log(self.mle_nz)
        self.info['estimators']['log_mmle_nz'] = self.log_mle_nz
    else:
        self.log_mle_nz = self.info['estimators']['log_mmle_nz']
        self.mle_nz = np.exp(self.log_mle_nz)
    return self.log_mle_nz

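# A sketch of what an optimizer like self.optimize could look like (an
# assumption for illustration; the real method is defined elsewhere in
# the class): maximize a log hyperposterior by minimizing its negative
# with scipy.optimize.minimize, starting from the supplied log n(z).
def _example_optimize(log_posterior, start):
    import numpy as np
    from scipy.optimize import minimize
    res = minimize(lambda log_nz: -log_posterior(log_nz),
                   np.asarray(start), method='Nelder-Mead')
    return res.x  # log n(z) bin values at the (local) maximum
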
def plot_prob_space(z_grid, p_space, plot_loc='', prepend='',
                    plot_name='prob_space.png'):
    """
    Plots the 2D probability space of z_spec, z_phot

    Parameters
    ----------
    z_grid: numpy.ndarray, float
        fine grid of redshifts
    p_space: chippr.gmix object
        probability distribution over the (z_spec, z_phot) plane,
        providing an `evaluate` method
    plot_loc: string, optional
        location in which to store plot
    prepend: str, optional
        prepend string to plot name
    plot_name: string, optional
        filename for plot
    """
    pu.set_up_plot()
    f = plt.figure(figsize=(5, 5))
    plt.subplot(1, 1, 1)
    grid_len = len(z_grid)
    grid_range = range(grid_len)
    # Evaluate the probability distribution on every (z_true, z_obs) pair,
    # flattening the 2D grid for evaluation and restoring it for plotting
    all_points = np.array([[(z_grid[kk], z_grid[jj]) for kk in grid_range]
                           for jj in grid_range])
    orig_shape = np.shape(all_points)
    all_vals = p_space.evaluate(
        all_points.reshape((orig_shape[0] * orig_shape[1], orig_shape[2])))
    all_vals = all_vals.reshape((orig_shape[0], orig_shape[1]))
    plt.pcolormesh(z_grid, z_grid, u.safe_log(all_vals), cmap='viridis')
    plt.plot(z_grid, z_grid, color='k')
    plt.colorbar()
    plt.xlabel(r'$z_{\mathrm{true}}$')
    plt.ylabel(r'$\mathrm{``data"}$')  # i.e. z_phot
    plt.axis([z_grid[0], z_grid[-1], z_grid[0], z_grid[-1]])
    f.savefig(os.path.join(plot_loc, prepend + plot_name),
              bbox_inches='tight', pad_inches=0, dpi=d.dpi)
    return

def calculate_stacked(self, vb=False):
    """
    Calculates the stacked estimator of the redshift density function

    Parameters
    ----------
    vb: boolean, optional
        True to print progress messages to stdout, False to suppress

    Returns
    -------
    log_stk_nz: ndarray, float
        array of logged redshift density function bin values
    """
    if 'log_stacked_nz' not in self.info['estimators']:
        # Sum the interim posteriors and renormalize to integrate to one
        self.stk_nz = np.sum(self.pdfs, axis=0)
        self.stk_nz /= np.dot(self.stk_nz, self.bin_difs)
        self.log_stk_nz = u.safe_log(self.stk_nz)
        self.info['estimators']['log_stacked_nz'] = self.log_stk_nz
    else:
        self.log_stk_nz = self.info['estimators']['log_stacked_nz']
        self.stk_nz = np.exp(self.log_stk_nz)
    return self.log_stk_nz

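# Standalone sketch of the stacked estimator (illustrative only): because
# each interim posterior integrates to one over the bins, summing and
# renormalizing is equivalent to averaging the PDFs.
def _example_stacked(pdfs, bin_ends):
    import numpy as np
    dz = bin_ends[1:] - bin_ends[:-1]
    stk = np.sum(pdfs, axis=0)
    return stk / np.dot(stk, dz)  # integrates to one over the bin grid
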
def plot_samples(info, plot_dir, prepend=''):
    """
    Plots a few random samples from the posterior distribution

    Parameters
    ----------
    info: dict
        dictionary of stored information from log_z_dens object
    plot_dir: string
        directory where plot should be stored
    prepend: str, optional
        prepend string to file names
    """
    pu.set_up_plot()
    f = plt.figure(figsize=(5, 10))
    sps = [f.add_subplot(2, 1, l + 1) for l in range(0, 2)]
    f.subplots_adjust(hspace=0, wspace=0)
    sps_log = sps[0]
    sps = sps[1]
    sps_log.set_xlim(info['bin_ends'][0], info['bin_ends'][-1])
    sps_log.set_ylabel(r'$\ln[n(z)]$')
    sps.set_xlim(info['bin_ends'][0], info['bin_ends'][-1])
    sps.set_xlabel(r'$z$')
    sps.set_ylabel(r'$n(z)$')
    sps.ticklabel_format(style='sci', axis='y')
    # Interim prior, in linear and log space
    pu.plot_step(sps, info['bin_ends'], np.exp(info['log_interim_prior']),
                 w=w_int, s=s_int, a=a_int, c=c_int, d=d_int, l=l_int + nz)
    pu.plot_step(sps_log, info['bin_ends'], info['log_interim_prior'],
                 w=w_int, s=s_int, a=a_int, c=c_int, d=d_int, l=l_int + lnz)
    if info['truth'] is not None:
        sps.plot(info['truth']['z_grid'], info['truth']['nz_grid'],
                 linewidth=w_tru, alpha=a_tru, color=c_tru,
                 label=l_tru + nz)
        sps_log.plot(info['truth']['z_grid'],
                     u.safe_log(info['truth']['nz_grid']),
                     linewidth=w_tru, alpha=a_tru, color=c_tru,
                     label=l_tru + lnz)
    # One-sigma bands from a Gaussian fit to the chains
    (locs, scales) = s.norm_fit(info['log_sampled_nz_meta_data']['chains'])
    for k in range(len(info['bin_ends']) - 1):
        x_errs = [info['bin_ends'][k], info['bin_ends'][k],
                  info['bin_ends'][k + 1], info['bin_ends'][k + 1]]
        log_y_errs = [locs[k] - scales[k], locs[k] + scales[k],
                      locs[k] + scales[k], locs[k] - scales[k]]
        sps_log.fill(x_errs, log_y_errs, color='k', alpha=0.1, linewidth=0.)
        sps.fill(x_errs, np.exp(log_y_errs), color='k', alpha=0.1,
                 linewidth=0.)
    # A few randomly chosen samples from the flattened chains
    shape = np.shape(info['log_sampled_nz_meta_data']['chains'])
    flat = info['log_sampled_nz_meta_data']['chains'].reshape(
        np.prod(shape[:-1]), shape[-1])
    random_samples = [np.random.randint(0, len(flat))
                      for i in range(d.plot_colors)]
    for i in range(d.plot_colors):
        pu.plot_step(sps_log, info['bin_ends'], flat[random_samples[i]],
                     s=s_smp, d=d_smp, w=w_smp, a=1., c=pu.colors[i])
        pu.plot_step(sps, info['bin_ends'], np.exp(flat[random_samples[i]]),
                     s=s_smp, d=d_smp, w=w_smp, a=1., c=pu.colors[i])
    # The mean of the Gaussian fit serves as the best-fit estimator
    pu.plot_step(sps_log, info['bin_ends'], locs, s=s_smp, d=d_smp, w=2.,
                 a=1., c='k', l=l_bfe + lnz)
    pu.plot_step(sps, info['bin_ends'], np.exp(locs), s=s_smp, d=d_smp,
                 w=2., a=1., c='k', l=l_bfe + nz)
    sps_log.legend(fontsize='x-small', loc='lower left')
    f.savefig(os.path.join(plot_dir, prepend + 'samples.png'),
              bbox_inches='tight', pad_inches=0)
    return

def plot_ivals(ivals, info, plot_dir, prepend=''):
    """
    Plots the initial values given to the sampler

    Parameters
    ----------
    ivals: np.ndarray, float
        (n_walkers, n_bins) array of initial values for sampler
    info: dict
        dictionary of stored information from log_z_dens object
    plot_dir: string
        location into which the plot will be saved
    prepend: str, optional
        prepend string to file names

    Returns
    -------
    f: matplotlib figure
        figure object
    """
    pu.set_up_plot()
    n_walkers = len(ivals)
    walkers = [np.random.randint(0, n_walkers)
               for i in range(d.plot_colors)]
    f = plt.figure(figsize=(10, 5))
    # Left panel: a few walkers' initial log n(z) against prior and truth
    sps_samp = f.add_subplot(1, 2, 1)
    for i in range(d.plot_colors):
        pu.plot_step(sps_samp, info['bin_ends'], ivals[walkers[i]],
                     c=pu.colors[i])
    pu.plot_step(sps_samp, info['bin_ends'], info['log_interim_prior'],
                 w=w_int, s=s_int, a=a_int, c=c_int, d=d_int, l=l_int + nz)
    if info['truth'] is not None:
        sps_samp.plot(info['truth']['z_grid'],
                      u.safe_log(info['truth']['nz_grid']),
                      linewidth=w_tru, alpha=a_tru, color=c_tru,
                      label=l_tru + nz)
    sps_samp.set_xlabel(r'$z$')
    sps_samp.set_ylabel(r'$\ln\left[n(z)\right]$')
    # Right panel: distribution of the log integrals of the initial values
    sps_sum = f.add_subplot(1, 2, 2)
    bin_difs = info['bin_ends'][1:] - info['bin_ends'][:-1]
    ival_integrals = np.dot(np.exp(ivals), bin_difs)
    log_ival_integrals = u.safe_log(ival_integrals)
    # `density=True` replaces the long-deprecated `normed=1`
    sps_sum.hist(log_ival_integrals, color='k', density=True)
    sps_sum.vlines(np.log(np.dot(np.exp(info['log_interim_prior']),
                                 bin_difs)),
                   0., 1., linewidth=w_int, linestyle=s_int, alpha=a_int,
                   color=c_int, dashes=d_int, label=l_int + nz)
    sps_sum.vlines(np.mean(log_ival_integrals), 0., 1., linewidth=w_bfe,
                   linestyle=s_bfe, alpha=a_bfe, color=c_bfe, dashes=d_bfe,
                   label=l_bfe + lnz)
    sps_sum.set_xlabel(r'$\ln\left[\int n(z)dz\right]$')
    sps_sum.set_ylabel(r'$p\left(\ln\left[\int n(z)dz\right]\right)$')
    f.savefig(os.path.join(plot_dir, prepend + 'ivals.png'),
              bbox_inches='tight', pad_inches=0)
    return f

def calculate_samples(self, ivals, n_accepted=d.n_accepted,
                      n_burned=d.n_burned, vb=False, n_procs=1,
                      no_data=0, no_prior=0):
    """
    Calculates samples estimating the redshift density function

    Parameters
    ----------
    ivals: numpy.ndarray, float
        initial values of log n(z) for each walker
    n_accepted: int, optional
        log10 number of samples to accept per walker
    n_burned: int, optional
        log10 number of samples between tests of burn-in condition
    vb: boolean, optional
        True to print progress messages to stdout, False to suppress
    n_procs: int, optional
        number of processors to use, defaults to single-thread
    no_data: boolean, optional
        True to exclude data contribution to hyperposterior
    no_prior: boolean, optional
        True to exclude prior contribution to hyperposterior

    Returns
    -------
    log_samples_nz: ndarray, float
        array of sampled log redshift density function bin values
    """
    if 'log_mean_sampled_nz' not in self.info['estimators']:
        self.n_walkers = len(ivals)
        # Choose which piece(s) of the hyperposterior to sample
        if no_data:
            def distribution(log_nz):
                return self.evaluate_log_hyper_prior(log_nz)
        elif no_prior:
            def distribution(log_nz):
                return self.evaluate_log_hyper_likelihood(log_nz)
        else:
            def distribution(log_nz):
                return self.evaluate_log_hyper_posterior(log_nz)
        self.sampler = emcee.EnsembleSampler(self.n_walkers, self.n_bins,
                                             distribution, threads=n_procs)
        self.burn_ins = 0
        if n_burned == 0:
            self.burning_in = False
        else:
            self.burning_in = True
        # Normalize the initial values without mutating the caller's array
        vals = ivals - u.safe_log(
            np.sum(np.exp(ivals) * self.bin_difs[np.newaxis, :],
                   axis=1))[:, np.newaxis]
        if vb:
            plots.plot_ivals(vals, self.info, self.plot_dir,
                             prepend=self.add_text)
        canvas = plots.set_up_burn_in_plots(self.n_bins, self.n_walkers)
        full_chain = np.array([[vals[w]] for w in range(self.n_walkers)])
        # Run short bursts until the Gelman-Rubin test signals convergence
        while self.burning_in:
            if vb:
                print('beginning sampling ' + str(self.burn_ins))
            burn_in_mcmc_outputs = self.sample(vals, 10**n_burned)
            chain = burn_in_mcmc_outputs['chains']
            burn_in_mcmc_outputs['chains'] -= u.safe_log(
                np.sum(np.exp(chain) *
                       self.bin_difs[np.newaxis, np.newaxis, :],
                       axis=2))[:, :, np.newaxis]
            with open(os.path.join(self.res_dir,
                                   'mcmc' + str(self.burn_ins) + '.p'),
                      'wb') as file_location:
                cpkl.dump(burn_in_mcmc_outputs, file_location)
            full_chain = np.concatenate(
                (full_chain, burn_in_mcmc_outputs['chains']), axis=1)
            if vb:
                canvas = plots.plot_sampler_progress(
                    canvas, burn_in_mcmc_outputs, full_chain,
                    self.burn_ins, self.plot_dir, prepend=self.add_text)
            self.burning_in = s.gr_test(full_chain)
            vals = np.array(
                [item[-1] for item in burn_in_mcmc_outputs['chains']])
            self.burn_ins += 1
        # Final production run
        mcmc_outputs = self.sample(vals, 10**n_accepted)
        chain = mcmc_outputs['chains']
        mcmc_outputs['chains'] -= u.safe_log(
            np.sum(np.exp(chain) *
                   self.bin_difs[np.newaxis, np.newaxis, :],
                   axis=2))[:, :, np.newaxis]
        full_chain = np.concatenate((full_chain, mcmc_outputs['chains']),
                                    axis=1)
        with open(os.path.join(self.res_dir, 'full_chain.p'),
                  'wb') as file_location:
            cpkl.dump(full_chain, file_location)
        self.log_smp_nz = mcmc_outputs['chains']
        self.smp_nz = np.exp(self.log_smp_nz)
        self.info['log_sampled_nz_meta_data'] = mcmc_outputs
        self.log_bfe_nz = s.norm_fit(self.log_smp_nz)[0]
        self.bfe_nz = np.exp(self.log_bfe_nz)
        self.info['estimators']['log_mean_sampled_nz'] = self.log_bfe_nz
    else:
        self.log_smp_nz = self.info['log_sampled_nz_meta_data']['chains']
        self.smp_nz = np.exp(self.log_smp_nz)
        self.log_bfe_nz = self.info['estimators']['log_mean_sampled_nz']
        self.bfe_nz = np.exp(self.log_bfe_nz)
    return self.log_smp_nz

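# Minimal emcee sketch (illustrative only; the method above uses the
# v2-style API with the `threads` keyword): sampling a toy log-posterior
# with an EnsembleSampler and pulling out the chains.
def _example_emcee():
    import numpy as np
    import emcee
    n_walkers, n_bins = 10, 4

    def log_post(log_nz):
        return -0.5 * np.sum(log_nz ** 2)  # standard-normal toy posterior

    sampler = emcee.EnsembleSampler(n_walkers, n_bins, log_post)
    ivals = np.random.randn(n_walkers, n_bins)
    sampler.run_mcmc(ivals, 100)
    # emcee v2 exposes sampler.chain with shape (n_walkers, n_steps,
    # n_bins); emcee v3 prefers sampler.get_chain()
    return sampler.chain
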
def __init__(self, catalog, hyperprior, truth=None, loc='.', prepend='',
             vb=False):
    """
    An object representing the redshift density function (normalized
    redshift distribution function)

    Parameters
    ----------
    catalog: chippr.catalog object
        dict containing bin endpoints, interim prior bin values, and
        interim posterior PDF bin values
    hyperprior: chippr.mvn object
        multivariate Gaussian distribution for hyperprior distribution
    truth: chippr.gmix object, optional
        true redshift density function expressed as univariate Gaussian
        mixture
    loc: string, optional
        directory into which to save results and plots made along the way
    prepend: str, optional
        prepend string to file names
    vb: boolean, optional
        True to print progress messages to stdout, False to suppress
    """
    self.info = {}
    self.add_text = prepend + '_'

    self.bin_ends = np.array(catalog['bin_ends'])
    self.bin_range = self.bin_ends[-1] - self.bin_ends[0]
    self.bin_mids = (self.bin_ends[1:] + self.bin_ends[:-1]) / 2.
    self.bin_difs = self.bin_ends[1:] - self.bin_ends[:-1]
    self.log_bin_difs = u.safe_log(self.bin_difs)
    self.n_bins = len(self.bin_mids)
    self.info['bin_ends'] = self.bin_ends

    self.log_int_pr = np.array(catalog['log_interim_prior'])
    self.int_pr = np.exp(self.log_int_pr)
    self.info['log_interim_prior'] = self.log_int_pr

    self.log_pdfs = np.array(catalog['log_interim_posteriors'])
    self.pdfs = np.exp(self.log_pdfs)
    self.n_pdfs = len(self.log_pdfs)
    self.info['log_interim_posteriors'] = self.log_pdfs

    if vb:
        print(str(self.n_bins) + ' bins, ' + str(len(self.log_pdfs)) +
              ' interim posterior PDFs')

    self.hyper_prior = hyperprior

    self.truth = truth
    self.info['truth'] = None
    if self.truth is not None:
        self.info['truth'] = {}
        self.tru_nz = np.zeros(self.n_bins)
        self.fine_zs = []
        self.fine_nz = []
        # Evaluate the true n(z) on a fine grid within each coarse bin
        # and integrate it to get the coarse-binned truth
        for b in range(self.n_bins):
            fine_z = np.linspace(self.bin_ends[b], self.bin_ends[b + 1],
                                 self.n_bins)
            self.fine_zs.extend(fine_z)
            fine_dz = (self.bin_ends[b + 1] -
                       self.bin_ends[b]) / self.n_bins
            fine_n = self.truth.evaluate(fine_z)
            self.fine_nz.extend(fine_n)
            coarse_nz = np.sum(fine_n) * fine_dz
            self.tru_nz[b] += coarse_nz
        self.tru_nz /= np.dot(self.tru_nz, self.bin_difs)
        self.log_tru_nz = u.safe_log(self.tru_nz)
        self.info['log_tru_nz'] = self.log_tru_nz
        self.info['truth']['z_grid'] = np.array(self.fine_zs)
        self.info['truth']['nz_grid'] = np.array(self.fine_nz)

    self.info['estimators'] = {}
    self.info['stats'] = {}

    self.dir = loc
    self.data_dir = os.path.join(loc, 'data')
    self.plot_dir = os.path.join(loc, 'plots')
    if not os.path.exists(self.plot_dir):
        os.makedirs(self.plot_dir)
    self.res_dir = os.path.join(loc, 'results')
    if not os.path.exists(self.res_dir):
        os.makedirs(self.res_dir)

    return

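# Toy catalog dict with the keys __init__ reads (a hypothetical example;
# values are placeholders for illustration only): a flat interim prior
# over [0, 1] has density 1 per unit redshift, so its log is zero.
def _example_catalog(n_bins=5, n_gals=3):
    import numpy as np
    bin_ends = np.linspace(0., 1., n_bins + 1)
    log_flat = np.zeros(n_bins)  # flat, normalized interim prior
    return {'bin_ends': bin_ends,
            'log_interim_prior': log_flat,
            'log_interim_posteriors': np.tile(log_flat, (n_gals, 1))}
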
def make_probs(self, vb=False):
    """
    Makes the continuous 2D probability distribution over z_spec, z_phot

    Parameters
    ----------
    vb: boolean
        print progress to stdout?

    Returns
    -------
    p_space: list, chippr.multi_dist objects
        joint probability distributions over (z_spec, z_phot), one per
        fine redshift bin

    Notes
    -----
    only one outlier population at a time for now
    """
    # Horizontal slices: top-hat in z_spec over each fine bin
    hor_funcs = [discrete(np.array([self.z_all[kk], self.z_all[kk + 1]]),
                          np.array([1.]))
                 for kk in range(self.n_tot)]
    x_alt = self._make_bias(self.z_fine)
    # should sigmas be proportional to z_true or bias*(1+z_true)?
    sigmas = self._make_scatter(x_alt)
    # Vertical slices: Gaussian in z_phot centered on the biased z_spec
    vert_funcs = [gauss(x_alt[kk], sigmas[kk])
                  for kk in range(self.n_tot)]
    # WILL REFACTOR THIS TO ADD SUPPORT FOR MULTIPLE OUTLIER POPULATIONS
    if self.params['catastrophic_outliers'] != '0':
        frac = self.params['outlier_fraction']
        rel_fracs = np.array([frac, 1. - frac])
        uniform_lf = discrete(np.array([self.bin_ends[0],
                                        self.bin_ends[-1]]),
                              np.array([1.]))
        if self.params['catastrophic_outliers'] == 'uniform':
            grid_funcs = [gmix(rel_fracs, [uniform_lf, vert_funcs[kk]],
                               limits=(self.bin_ends[0],
                                       self.bin_ends[-1]))
                          for kk in range(self.n_tot)]
        else:
            outlier_lf = gauss(self.params['outlier_mean'],
                               self.params['outlier_sigma']**2)
            if self.params['catastrophic_outliers'] == 'template':
                grid_funcs = [gmix(rel_fracs,
                                   [outlier_lf, vert_funcs[kk]],
                                   limits=(self.bin_ends[0],
                                           self.bin_ends[-1]))
                              for kk in range(self.n_tot)]
            elif self.params['catastrophic_outliers'] == 'training':
                # Outlier fraction varies with z_spec, tracking the
                # normalized outlier likelihood on the fine grid
                full_pdf = np.exp(u.safe_log(outlier_lf.pdf(self.z_fine)))
                intermediate = np.dot(full_pdf,
                                      np.ones(self.n_tot) * self.dz_fine)
                full_pdf = full_pdf / intermediate[np.newaxis]
                fracs = np.array([full_pdf, np.ones(self.n_tot)]).T
                fracs = fracs * np.array([frac, 1. - frac])[np.newaxis, :]
                grid_funcs = [gmix(fracs[kk],
                                   [uniform_lf, vert_funcs[kk]],
                                   limits=(self.bin_ends[0],
                                           self.bin_ends[-1]))
                              for kk in range(self.n_tot)]
    else:
        grid_funcs = vert_funcs
    # Joint distributions: slice selector in z_spec times the (possibly
    # outlier-contaminated) likelihood in z_phot
    p_space = [multi_dist([hor_funcs[kk], grid_funcs[kk]])
               for kk in range(self.n_tot)]
    return p_space

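# Conceptual sketch of the same construction (with scipy.stats standing
# in for chippr's internal distributions, and hypothetical bias/outlier
# parameters): the joint density is an inlier Gaussian in z_phot centered
# on the biased z_spec, mixed with an outlier component independent of
# z_spec, evaluated on a shared grid.
def _example_prob_space(z_grid, bias=lambda z: z, sigma=0.05,
                        out_frac=0.1, out_mean=0.5, out_sigma=0.2):
    import numpy as np
    from scipy.stats import norm
    # inlier: rows index z_spec, columns index z_phot
    inlier = norm.pdf(z_grid[None, :], loc=bias(z_grid)[:, None],
                      scale=sigma)
    # outlier: same for every z_spec row
    outlier = norm.pdf(z_grid[None, :], loc=out_mean, scale=out_sigma)
    return (1. - out_frac) * inlier + out_frac * outlier
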
def create(self, truth, int_pr, N=d.n_gals, vb=False):
    """
    Function creating a catalog of interim posterior probability
    distributions, will split this up into helper functions

    Parameters
    ----------
    truth: chippr.gmix object or chippr.gauss object or chippr.discrete
        object
        true redshift distribution object
    int_pr: chippr.gmix object or chippr.gauss object or chippr.discrete
        object
        interim prior distribution object
    N: int, optional
        log10 number of galaxies to simulate
    vb: boolean, optional
        True to print progress messages to stdout, False to suppress

    Returns
    -------
    self.cat: dict
        dictionary comprising catalog information
    """
    self.N = 10**N
    self.N_range = range(self.N)
    self.truth = truth
    self.proc_bins()
    prob_components = self.make_probs()
    # Weight each component by the true n(z) to define the drawing
    # distribution over (z_true, z_obs)
    hor_amps = self.truth.evaluate(self.z_fine) * self.bin_difs_fine
    self.pspace_draw = gmix(hor_amps, prob_components)
    if vb:
        plots.plot_prob_space(self.z_fine, self.pspace_draw,
                              plot_loc=self.plot_dir,
                              prepend=self.cat_name + 'draw_')
    # Sample (z_true, z_obs) pairs from the joint distribution
    self.samps = self.pspace_draw.sample(self.N)
    self.cat['true_vals'] = self.samps
    if vb:
        plots.plot_true_histogram(self.samps.T[0],
                                  n_bins=(self.n_coarse, self.n_tot),
                                  plot_loc=self.plot_dir,
                                  prepend=self.cat_name)
    # Evaluate interim posteriors as slices at constant z_phot, under
    # the interim prior rather than the true n(z)
    self.int_pr = int_pr
    int_pr_fine = np.array([self.int_pr.pdf(self.z_fine)])
    self.pspace_eval = gmix(int_pr_fine, prob_components)
    if vb:
        plots.plot_prob_space(self.z_fine, self.pspace_eval,
                              plot_loc=self.plot_dir,
                              prepend=self.cat_name + 'eval_')
    self.obs_lfs = self.evaluate_lfs(self.pspace_eval)
    # Coarsen the fine-grid likelihoods and prior onto the output bins
    pfs_coarse = self.coarsify(self.obs_lfs)
    int_pr_coarse = self.coarsify(int_pr_fine)
    if vb:
        plots.plot_mega_scatter(self.samps, self.obs_lfs, self.z_fine,
                                self.bin_ends,
                                truth=[self.z_fine, hor_amps],
                                plot_loc=self.plot_dir,
                                prepend=self.cat_name,
                                int_pr=[self.z_fine, int_pr_fine[0]])
    self.cat['bin_ends'] = self.bin_ends
    self.cat['log_interim_prior'] = u.safe_log(int_pr_coarse[0])
    self.cat['log_interim_posteriors'] = u.safe_log(pfs_coarse)
    return self.cat