Exemplo n.º 1
0
def calculate_kld(pe, qe, vb=False):
    """
    Computes the Kullback-Leibler Divergence between two gridded PDFs.

    Parameters
    ----------
    pe: numpy.ndarray, float
        evaluations on a grid of the distribution `p`, whose distance from
        `q` is wanted
    qe: numpy.ndarray, float
        evaluations on a grid of the distribution `q`, the distance to `p`
        being measured from it
    vb: boolean
        report on progress to stdout? (not used by this function)

    Returns
    -------
    Dpq: float
        the value of the Kullback-Leibler Divergence from `q` to `p`
    """
    # Rescale both sets of evaluations to unit sum, so that a plain sum over
    # grid points stands in (very approximately!) for the integral
    p_unit = pe / np.sum(pe)
    q_unit = qe / np.sum(qe)
    # Difference of the logs of the rescaled PDFs
    # NOTE(review): u.safe_log is presumably a log guarded against
    # non-positive input -- confirm against the utilities module
    log_ratio = u.safe_log(p_unit) - u.safe_log(q_unit)
    # Expectation of the log ratio under p
    return np.sum(p_unit * log_ratio)
Exemplo n.º 2
0
    def calculate_mexp(self, vb=False):
        """
        Computes the marginalized expected value estimator of the redshift
        density function, reusing the stored result if there is one

        Parameters
        ----------
        vb: boolean, optional
            True to print progress messages to stdout, False to suppress
            (not used by this method)

        Returns
        -------
        log_exp_nz: ndarray, float
            array of logged redshift density function bin values
        """
        if 'log_mexp_nz' in self.info['estimators']:
            # Already computed: restore both the log and linear versions
            self.log_exp_nz = self.info['estimators']['log_mexp_nz']
            self.exp_nz = np.exp(self.log_exp_nz)
            return self.log_exp_nz

        # One expected value per PDF: sum over bins of midpoint * PDF * width
        weighted = self.bin_mids * self.pdfs * self.bin_difs
        expected_zs = [sum(row) for row in weighted]

        lower_edges = self.bin_ends[:self.n_bins]
        upper_edges = self.bin_ends[1:self.n_bins + 1]
        self.exp_nz = np.zeros(self.n_bins)
        for expected_z in expected_zs:
            # Strict inequalities: a value exactly on a bin edge is not counted
            inside = (expected_z > lower_edges) & (expected_z < upper_edges)
            self.exp_nz[inside] += 1.

        # Turn the counts into a density
        self.exp_nz /= self.bin_difs * self.n_pdfs
        self.log_exp_nz = u.safe_log(self.exp_nz)
        self.info['estimators']['log_mexp_nz'] = self.log_exp_nz
        return self.log_exp_nz
Exemplo n.º 3
0
    def evaluate_log_hyper_likelihood(self, log_nz):
        """
        Evaluates the log hyperlikelihood at a given set of logged redshift
        density bin values

        Parameters
        ----------
        log_nz: numpy.ndarray, float
            vector of logged redshift density bin values at which to evaluate
            the hyperlikelihood

        Returns
        -------
        log_hyper_likelihood: float
            log likelihood probability associated with parameters in log_nz
        """
        unnormed = np.exp(log_nz)
        # Rescale so that the density integrates to one over the bins
        normed = unnormed / np.dot(unnormed, self.bin_difs)

        # For every PDF: sum over bins of n(z) * PDF / interim prior * width
        integrand = (normed[None, :] * self.pdfs / self.int_pr[None, :] *
                     self.bin_difs)
        per_pdf = np.sum(integrand, axis=1)

        # Log of the integral of the already-rescaled density; the original
        # author was still testing whether this norm step is necessary
        residual_log_norm = np.log(np.dot(normed, self.bin_difs))

        # An earlier formulation relied on a precomputed term and "used to
        # work": np.dot(np.exp(log_nz + self.precomputed), self.bin_difs)
        return np.sum(u.safe_log(per_pdf)) - residual_log_norm
Exemplo n.º 4
0
    def calculate_mmap(self, vb=False):
        """
        Calculates the marginalized maximum a posteriori estimator of the
        redshift density function, reusing the stored result if there is one

        Parameters
        ----------
        vb: boolean, optional
            True to print progress messages to stdout, False to suppress
            (not used by this method)

        Returns
        -------
        log_map_nz: ndarray, float
            array of logged redshift density function bin values
        """
        if 'log_mmap_nz' not in self.info['estimators']:
            # Count, for every bin, the PDFs whose maximum falls in that bin
            self.map_nz = np.zeros(self.n_bins)
            for log_pdf in self.log_pdfs:
                self.map_nz[np.argmax(log_pdf)] += 1.
            # Convert counts to a density using each bin's own width.  The
            # previous code divided every bin by the width of whichever bin
            # the last PDF peaked in (a leaked loop variable), which is wrong
            # for non-uniform bins and failed when there were no PDFs.
            self.map_nz /= self.bin_difs * self.n_pdfs
            self.log_map_nz = u.safe_log(self.map_nz)
            self.info['estimators']['log_mmap_nz'] = self.log_map_nz
        else:
            # Already computed: restore both the log and linear versions
            self.log_map_nz = self.info['estimators']['log_mmap_nz']
            self.map_nz = np.exp(self.log_map_nz)

        return self.log_map_nz
Exemplo n.º 5
0
    def calculate_mmle(self, start, vb=False, no_data=0, no_prior=0):
        """
        Computes the marginalized maximum likelihood estimator of the
        redshift density function, reusing the stored result if there is one

        Parameters
        ----------
        start: numpy.ndarray, float
            array of log redshift density function bin values at which to begin
            optimization
        vb: boolean, optional
            True to print progress messages to stdout, False to suppress
        no_data: boolean, optional
            True to exclude data contribution to hyperposterior
        no_prior: boolean, optional
            True to exclude prior contribution to hyperposterior

        Returns
        -------
        log_mle_nz: numpy.ndarray, float
            array of logged redshift density function bin values maximizing
            hyperposterior
        """
        if 'log_mmle_nz' in self.info['estimators']:
            # Already computed: restore both the log and linear versions
            self.log_mle_nz = self.info['estimators']['log_mmle_nz']
            self.mle_nz = np.exp(self.log_mle_nz)
            return self.log_mle_nz

        optimum = self.optimize(start,
                                no_data=no_data,
                                no_prior=no_prior,
                                vb=vb)
        # Rescale the optimizer output so the density integrates to one
        unnormed = np.exp(optimum)
        self.mle_nz = unnormed / np.dot(unnormed, self.bin_difs)
        self.log_mle_nz = u.safe_log(self.mle_nz)
        self.info['estimators']['log_mmle_nz'] = self.log_mle_nz
        return self.log_mle_nz
Exemplo n.º 6
0
def plot_prob_space(z_grid, p_space, plot_loc='', prepend='', plot_name='prob_space.png'):
    """
    Draws the log of a 2D probability space over (z_spec, z_phot) as a color
    mesh and saves it to file

    Parameters
    ----------
    z_grid: numpy.ndarray, float
        fine grid of redshifts, used for both axes
    p_space: object
        distribution exposing `evaluate`, called on an array of 2D points
    plot_loc: string, optional
        location in which to store plot
    prepend: str, optional
        prepend string to plot name
    plot_name: string, optional
        filename for plot
    """
    pu.set_up_plot()
    f = plt.figure(figsize=(5, 5))
    plt.subplot(1, 1, 1)

    indices = range(len(z_grid))
    # Entry [row][col] is the coordinate pair (z_grid[col], z_grid[row])
    point_grid = np.array(
        [[(z_grid[col], z_grid[row]) for col in indices] for row in indices])
    grid_shape = np.shape(point_grid)
    # Evaluate all points in one call, then fold back onto the 2D grid
    flat_points = point_grid.reshape(
        (grid_shape[0] * grid_shape[1], grid_shape[2]))
    all_vals = p_space.evaluate(flat_points).reshape(
        (grid_shape[0], grid_shape[1]))

    plt.pcolormesh(z_grid, z_grid, u.safe_log(all_vals), cmap='viridis')
    # Diagonal reference line
    plt.plot(z_grid, z_grid, color='k')
    plt.colorbar()
    plt.xlabel(r'$z_{\mathrm{true}}$')
    plt.ylabel(r'$\mathrm{``data"}$')
    lower, upper = z_grid[0], z_grid[-1]
    plt.axis([lower, upper, lower, upper])

    out_path = os.path.join(plot_loc, prepend + plot_name)
    f.savefig(out_path, bbox_inches='tight', pad_inches=0, dpi=d.dpi)
Exemplo n.º 7
0
    def calculate_stacked(self, vb=False):
        """
        Computes the stacked estimator of the redshift density function,
        reusing the stored result if there is one

        Parameters
        ----------
        vb: boolean, optional
            True to print progress messages to stdout, False to suppress
            (not used by this method)

        Returns
        -------
        log_stk_nz: ndarray, float
            array of logged redshift density function bin values
        """
        if 'log_stacked_nz' in self.info['estimators']:
            # Already computed: restore both the log and linear versions
            self.log_stk_nz = self.info['estimators']['log_stacked_nz']
            self.stk_nz = np.exp(self.log_stk_nz)
            return self.log_stk_nz

        # Add up all the PDFs bin by bin ...
        self.stk_nz = np.sum(self.pdfs, axis=0)
        # ... and rescale so the result integrates to one over the bins
        self.stk_nz /= np.dot(self.stk_nz, self.bin_difs)
        self.log_stk_nz = u.safe_log(self.stk_nz)
        self.info['estimators']['log_stacked_nz'] = self.log_stk_nz
        return self.log_stk_nz
Exemplo n.º 8
0
def plot_samples(info, plot_dir, prepend=''):
    """
    Plots a few random samples from the posterior distribution

    Two stacked panels are drawn, the upper in log n(z) and the lower in
    n(z).  Each shows the interim prior, the truth when available, a shaded
    band of `locs +/- scales` per bin from `s.norm_fit` on the chains,
    `d.plot_colors` randomly chosen samples, and the fitted `locs` themselves.
    The figure is written to `<plot_dir>/<prepend>samples.png`.

    Parameters
    ----------
    info: dict
        dictionary of stored information from log_z_dens object
    plot_dir: string
        directory where plot should be stored
    prepend: str, optional
        prepend string to file names
    """
    pu.set_up_plot()

    f = plt.figure(figsize=(5, 10))
    # `range` replaces the Python 2-only `xrange`, which raises NameError
    # under Python 3; the first subplot is the log panel, the second linear
    sps_log, sps = [f.add_subplot(2, 1, row + 1) for row in range(2)]
    f.subplots_adjust(hspace=0, wspace=0)

    sps_log.set_xlim(info['bin_ends'][0], info['bin_ends'][-1])
    sps_log.set_ylabel(r'$\ln[n(z)]$')
    sps.set_xlim(info['bin_ends'][0], info['bin_ends'][-1])
    sps.set_xlabel(r'$z$')
    sps.set_ylabel(r'$n(z)$')
    sps.ticklabel_format(style='sci', axis='y')

    # Interim prior on both panels
    pu.plot_step(sps,
                 info['bin_ends'],
                 np.exp(info['log_interim_prior']),
                 w=w_int,
                 s=s_int,
                 a=a_int,
                 c=c_int,
                 d=d_int,
                 l=l_int + nz)
    pu.plot_step(sps_log,
                 info['bin_ends'],
                 info['log_interim_prior'],
                 w=w_int,
                 s=s_int,
                 a=a_int,
                 c=c_int,
                 d=d_int,
                 l=l_int + lnz)
    # True n(z) on both panels, when it is known
    if info['truth'] is not None:
        sps.plot(info['truth']['z_grid'],
                 info['truth']['nz_grid'],
                 linewidth=w_tru,
                 alpha=a_tru,
                 color=c_tru,
                 label=l_tru + nz)
        sps_log.plot(info['truth']['z_grid'],
                     u.safe_log(info['truth']['nz_grid']),
                     linewidth=w_tru,
                     alpha=a_tru,
                     color=c_tru,
                     label=l_tru + lnz)

    # Shaded rectangle per bin spanning locs - scales to locs + scales
    (locs, scales) = s.norm_fit(info['log_sampled_nz_meta_data']['chains'])
    for k in range(len(info['bin_ends']) - 1):
        x_errs = [
            info['bin_ends'][k], info['bin_ends'][k], info['bin_ends'][k + 1],
            info['bin_ends'][k + 1]
        ]
        log_y_errs = [
            locs[k] - scales[k], locs[k] + scales[k], locs[k] + scales[k],
            locs[k] - scales[k]
        ]
        sps_log.fill(x_errs, log_y_errs, color='k', alpha=0.1, linewidth=0.)
        sps.fill(x_errs,
                 np.exp(log_y_errs),
                 color='k',
                 alpha=0.1,
                 linewidth=0.)

    # Collapse every leading axis of the chains so each row is one sample
    shape = np.shape(info['log_sampled_nz_meta_data']['chains'])
    flat = info['log_sampled_nz_meta_data']['chains'].reshape(
        np.prod(shape[:-1]), shape[-1])
    random_samples = [
        np.random.randint(0, len(flat)) for i in range(d.plot_colors)
    ]
    for i in range(d.plot_colors):
        pu.plot_step(sps_log,
                     info['bin_ends'],
                     flat[random_samples[i]],
                     s=s_smp,
                     d=d_smp,
                     w=w_smp,
                     a=1.,
                     c=pu.colors[i])
        pu.plot_step(sps,
                     info['bin_ends'],
                     np.exp(flat[random_samples[i]]),
                     s=s_smp,
                     d=d_smp,
                     w=w_smp,
                     a=1.,
                     c=pu.colors[i])
    # Fitted locations on top, in black
    pu.plot_step(sps_log,
                 info['bin_ends'],
                 locs,
                 s=s_smp,
                 d=d_smp,
                 w=2.,
                 a=1.,
                 c='k',
                 l=l_bfe + lnz)
    pu.plot_step(sps,
                 info['bin_ends'],
                 np.exp(locs),
                 s=s_smp,
                 d=d_smp,
                 w=2.,
                 a=1.,
                 c='k',
                 l=l_bfe + nz)

    sps_log.legend(fontsize='x-small', loc='lower left')
    # These plain-text labels replace the LaTeX axis labels set above
    sps.set_xlabel('x')
    sps_log.set_ylabel('Log probability density')
    sps.set_ylabel('Probability density')
    f.savefig(os.path.join(plot_dir, prepend + 'samples.png'),
              bbox_inches='tight',
              pad_inches=0)

    return
Exemplo n.º 9
0
def plot_ivals(ivals, info, plot_dir, prepend=''):
    """
    Plots the initial values given to the sampler

    The left panel shows `d.plot_colors` randomly chosen walkers' initial
    values together with the interim prior and, when available, the log of
    the truth.  The right panel shows a histogram of the log of each walker's
    integral over the bins, with vertical lines at the interim prior's value
    and at the mean over walkers.  The figure is written to
    `<plot_dir>/<prepend>ivals.png`.

    Parameters
    ----------
    ivals: np.ndarray, float
        (n_walkers, n_bins) array of initial values for sampler
    info: dict
        dictionary of stored information from log_z_dens object
    plot_dir: string
        location into which the plot will be saved
    prepend: str, optional
        prepend string to file names

    Returns
    -------
    f: matplotlib figure
        figure object
    """
    pu.set_up_plot()
    n_walkers = len(ivals)
    walkers = [np.random.randint(0, n_walkers) for i in range(d.plot_colors)]

    f = plt.figure(figsize=(10, 5))
    sps_samp = f.add_subplot(1, 2, 1)
    for i in range(d.plot_colors):
        pu.plot_step(sps_samp,
                     info['bin_ends'],
                     ivals[walkers[i]],
                     c=pu.colors[i])
    pu.plot_step(sps_samp,
                 info['bin_ends'],
                 info['log_interim_prior'],
                 w=w_int,
                 s=s_int,
                 a=a_int,
                 c=c_int,
                 d=d_int,
                 l=l_int + nz)
    if info['truth'] is not None:
        # u.safe_log, as in plot_samples, instead of a raw np.log of the
        # truth grid
        sps_samp.plot(info['truth']['z_grid'],
                      u.safe_log(info['truth']['nz_grid']),
                      linewidth=w_tru,
                      alpha=a_tru,
                      color=c_tru,
                      label=l_tru + nz)
    sps_samp.set_xlabel(r'$z$')
    sps_samp.set_ylabel(r'$\ln\left[n(z)\right]$')

    sps_sum = f.add_subplot(1, 2, 2)
    bin_difs = info['bin_ends'][1:] - info['bin_ends'][:-1]
    # Integral over the bins of each walker's initial n(z)
    ival_integrals = np.dot(np.exp(ivals), bin_difs)
    log_ival_integrals = u.safe_log(ival_integrals)
    # NOTE(review): `normed` was removed from Matplotlib's hist in 3.1; use
    # `density=True` once the minimum supported Matplotlib is at least 2.1
    sps_sum.hist(log_ival_integrals, color='k', normed=1)
    sps_sum.vlines(np.log(np.dot(np.exp(info['log_interim_prior']), bin_difs)),
                   0.,
                   1.,
                   linewidth=w_int,
                   linestyle=s_int,
                   alpha=a_int,
                   color=c_int,
                   dashes=d_int,
                   label=l_int + nz)
    sps_sum.vlines(np.mean(log_ival_integrals),
                   0.,
                   1.,
                   linewidth=w_bfe,
                   linestyle=s_bfe,
                   alpha=a_bfe,
                   color=c_bfe,
                   dashes=d_bfe,
                   label=l_bfe + lnz)

    sps_sum.set_xlabel(r'$\ln\left[\int n(z)dz\right]$')
    sps_sum.set_ylabel(r'$p\left(\ln\left[\int n(z)dz\right]\right)$')

    f.savefig(os.path.join(plot_dir, prepend + 'ivals.png'),
              bbox_inches='tight',
              pad_inches=0)

    # The docstring has always promised the figure; return it
    return f
Exemplo n.º 10
0
    def calculate_samples(self,
                          ivals,
                          n_accepted=d.n_accepted,
                          n_burned=d.n_burned,
                          vb=False,
                          n_procs=1,
                          no_data=0,
                          no_prior=0):
        """
        Calculates samples estimating the redshift density function

        Burn-in batches of 10**n_burned samples are drawn until `s.gr_test`
        on the accumulated chain returns a false value (or skipped entirely
        when n_burned is 0), then 10**n_accepted samples are drawn and kept.
        Every batch is shifted so each sample integrates to one over the bins.
        Each burn-in batch is pickled to `mcmc<i>.p` and the full chain to
        `full_chain.p` in the results directory.  If samples were already
        computed, the stored ones are restored instead.

        Parameters
        ----------
        ivals: numpy.ndarray, float
            initial values of log n(z) for each walker; not modified
        n_accepted: int, optional
            log10 number of samples to accept per walker
        n_burned: int, optional
            log10 number of samples between tests of burn-in condition
        n_procs: int, optional
            number of processors to use, defaults to single-thread
        vb: boolean, optional
            True to print progress messages to stdout, False to suppress
        no_data: boolean, optional
            True to exclude data contribution to hyperposterior
        no_prior: boolean, optional
            True to exclude prior contribution to hyperposterior

        Returns
        -------
        log_samples_nz: ndarray, float
            array of sampled log redshift density function bin values
        """
        if 'log_mean_sampled_nz' not in self.info['estimators']:
            self.n_walkers = len(ivals)

            def log_integral(log_nz):
                # Log of the integral over the bins of exp(log_nz), with a
                # trailing axis so it broadcasts against log_nz
                return u.safe_log(
                    np.sum(np.exp(log_nz) * self.bin_difs,
                           axis=-1))[..., np.newaxis]

            # Choose the target distribution for the sampler
            if no_data:

                def distribution(log_nz):
                    return self.evaluate_log_hyper_prior(log_nz)
            elif no_prior:

                def distribution(log_nz):
                    return self.evaluate_log_hyper_likelihood(log_nz)
            else:

                def distribution(log_nz):
                    return self.evaluate_log_hyper_posterior(log_nz)

            self.sampler = emcee.EnsembleSampler(self.n_walkers,
                                                 self.n_bins,
                                                 distribution,
                                                 threads=n_procs)
            self.burn_ins = 0
            self.burning_in = n_burned != 0
            # Normalize into a new array; the previous in-place subtraction
            # silently modified the caller's `ivals`
            vals = ivals - log_integral(ivals)
            if vb:
                plots.plot_ivals(vals,
                                 self.info,
                                 self.plot_dir,
                                 prepend=self.add_text)
                canvas = plots.set_up_burn_in_plots(self.n_bins,
                                                    self.n_walkers)
            # Seed the accumulated chain with the starting positions as a
            # single step per walker
            full_chain = np.array([[vals[w]] for w in range(self.n_walkers)])
            while self.burning_in:
                if vb:
                    print('beginning sampling ' + str(self.burn_ins))
                burn_in_mcmc_outputs = self.sample(vals, 10**n_burned)
                burn_in_mcmc_outputs['chains'] -= log_integral(
                    burn_in_mcmc_outputs['chains'])
                with open(
                        os.path.join(self.res_dir,
                                     'mcmc' + str(self.burn_ins) + '.p'),
                        'wb') as file_location:
                    cpkl.dump(burn_in_mcmc_outputs, file_location)
                full_chain = np.concatenate(
                    (full_chain, burn_in_mcmc_outputs['chains']), axis=1)
                if vb:
                    canvas = plots.plot_sampler_progress(canvas,
                                                         burn_in_mcmc_outputs,
                                                         full_chain,
                                                         self.burn_ins,
                                                         self.plot_dir,
                                                         prepend=self.add_text)
                self.burning_in = s.gr_test(full_chain)
                # Restart the next batch from each walker's last position
                vals = np.array(
                    [item[-1] for item in burn_in_mcmc_outputs['chains']])
                self.burn_ins += 1

            mcmc_outputs = self.sample(vals, 10**n_accepted)
            mcmc_outputs['chains'] -= log_integral(mcmc_outputs['chains'])
            full_chain = np.concatenate((full_chain, mcmc_outputs['chains']),
                                        axis=1)
            with open(os.path.join(self.res_dir, 'full_chain.p'),
                      'wb') as file_location:
                cpkl.dump(full_chain, file_location)

            self.log_smp_nz = mcmc_outputs['chains']
            self.smp_nz = np.exp(self.log_smp_nz)
            self.info['log_sampled_nz_meta_data'] = mcmc_outputs
            self.log_bfe_nz = s.norm_fit(self.log_smp_nz)[0]
            self.bfe_nz = np.exp(self.log_bfe_nz)
            self.info['estimators']['log_mean_sampled_nz'] = self.log_bfe_nz
        else:
            # The stored meta data is the sampler output dict; the samples
            # live under its 'chains' key (the dict itself was previously
            # assigned here, which made np.exp fail)
            self.log_smp_nz = self.info['log_sampled_nz_meta_data']['chains']
            self.smp_nz = np.exp(self.log_smp_nz)
            self.log_bfe_nz = self.info['estimators']['log_mean_sampled_nz']
            # Exponentiate the best-fit estimate, not the full sample array
            self.bfe_nz = np.exp(self.log_bfe_nz)

        return self.log_smp_nz
Exemplo n.º 11
0
    def __init__(self,
                 catalog,
                 hyperprior,
                 truth=None,
                 loc='.',
                 prepend='',
                 vb=False):
        """
        Builds an object representing the redshift density function
        (normalized redshift distribution function) from a catalog

        Parameters
        ----------
        catalog: chippr.catalog object
            dict containing bin endpoints, interim prior bin values, and
            interim posterior PDF bin values
        hyperprior: chippr.mvn object
            multivariate Gaussian distribution for hyperprior distribution
        truth: chippr.gmix object, optional
            true redshift density function expressed as univariate Gaussian
            mixture
        loc: string, optional
            directory into which to save results and plots made along the way
        prepend: str, optional
            prepend string to file names
        vb: boolean, optional
            True to print progress messages to stdout, False to suppress
        """
        self.info = {}
        self.add_text = prepend + '_'

        # Bin geometry derived from the catalog's bin endpoints
        self.bin_ends = np.array(catalog['bin_ends'])
        lower_ends = self.bin_ends[:-1]
        upper_ends = self.bin_ends[1:]
        self.bin_range = lower_ends - self.bin_ends[0]
        self.bin_mids = (upper_ends + lower_ends) / 2.
        self.bin_difs = upper_ends - lower_ends
        self.log_bin_difs = u.safe_log(self.bin_difs)
        self.n_bins = len(self.bin_mids)
        self.info['bin_ends'] = self.bin_ends

        # Interim prior, in log and linear form
        self.log_int_pr = np.array(catalog['log_interim_prior'])
        self.int_pr = np.exp(self.log_int_pr)
        self.info['log_interim_prior'] = self.log_int_pr

        # Interim posteriors, in log and linear form
        self.log_pdfs = np.array(catalog['log_interim_posteriors'])
        self.pdfs = np.exp(self.log_pdfs)
        self.n_pdfs = len(self.log_pdfs)
        self.info['log_interim_posteriors'] = self.log_pdfs

        if vb:
            print(
                str(self.n_bins) + ' bins, ' + str(len(self.log_pdfs)) +
                ' interim posterior PDFs')

        self.hyper_prior = hyperprior

        self.truth = truth
        self.info['truth'] = None
        if self.truth is not None:
            self.info['truth'] = {}
            self.tru_nz = np.zeros(self.n_bins)
            self.fine_zs = []
            self.fine_nz = []
            # Evaluate the truth on n_bins points inside every bin and sum
            # them into one coarse value per bin
            bin_pairs = zip(self.bin_ends[:-1], self.bin_ends[1:])
            for b, (bin_lo, bin_hi) in enumerate(bin_pairs):
                fine_z = np.linspace(bin_lo, bin_hi, self.n_bins)
                self.fine_zs.extend(fine_z)
                fine_dz = (bin_hi - bin_lo) / self.n_bins
                fine_n = self.truth.evaluate(fine_z)
                self.fine_nz.extend(fine_n)
                self.tru_nz[b] += np.sum(fine_n) * fine_dz
            # Rescale the coarse truth to integrate to one over the bins
            self.tru_nz /= np.dot(self.tru_nz, self.bin_difs)
            self.log_tru_nz = u.safe_log(self.tru_nz)
            self.info['log_tru_nz'] = self.log_tru_nz
            self.info['truth']['z_grid'] = np.array(self.fine_zs)
            self.info['truth']['nz_grid'] = np.array(self.fine_nz)

        self.info['estimators'] = {}
        self.info['stats'] = {}

        # Output locations; only the plot and results directories are created
        self.dir = loc
        self.data_dir = os.path.join(loc, 'data')
        self.plot_dir = os.path.join(loc, 'plots')
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)
        self.res_dir = os.path.join(loc, 'results')
        if not os.path.exists(self.res_dir):
            os.makedirs(self.res_dir)
Exemplo n.º 12
0
    def make_probs(self, vb=False):
        """
        Makes the continuous 2D probability distribution over z_spec, z_phot

        One component is built per fine grid cell: a `discrete` distribution
        over the cell in one dimension paired, through `multi_dist`, with a
        distribution in the other dimension.  That second distribution is a
        `gauss` centred on the biased redshift, optionally mixed with an
        outlier term chosen by `self.params['catastrophic_outliers']`.

        Parameters
        ----------
        vb: boolean
            print progress to stdout? (not used by this method)

        Returns
        -------
        p_space: list
            one `multi_dist` object per fine grid cell

        Raises
        ------
        ValueError
            if `self.params['catastrophic_outliers']` is not one of '0',
            'uniform', 'template' or 'training'

        Notes
        -----
        only one outlier population at a time for now
        """
        cells = range(self.n_tot)
        hor_funcs = [
            discrete(np.array([self.z_all[kk], self.z_all[kk + 1]]),
                     np.array([1.])) for kk in cells
        ]

        x_alt = self._make_bias(self.z_fine)

        # should sigmas be proportional to z_true or bias*(1+z_true)?
        sigmas = self._make_scatter(x_alt)

        vert_funcs = [gauss(x_alt[kk], sigmas[kk]) for kk in cells]

        # TODO: refactor to add support for multiple outlier populations
        outlier_type = self.params['catastrophic_outliers']
        if outlier_type == '0':
            grid_funcs = vert_funcs
        else:
            frac = self.params['outlier_fraction']
            limits = (self.bin_ends[0], self.bin_ends[-1])
            uniform_lf = discrete(
                np.array([self.bin_ends[0], self.bin_ends[-1]]),
                np.array([1.]))
            if outlier_type == 'uniform':
                # Same outlier fraction everywhere, uniform outlier term
                rel_fracs = np.array([frac, 1. - frac])
                grid_funcs = [
                    gmix(rel_fracs, [uniform_lf, vert_funcs[kk]],
                         limits=limits) for kk in cells
                ]
            elif outlier_type == 'template':
                # Same outlier fraction everywhere, Gaussian outlier term
                rel_fracs = np.array([frac, 1. - frac])
                outlier_lf = gauss(self.params['outlier_mean'],
                                   self.params['outlier_sigma']**2)
                grid_funcs = [
                    gmix(rel_fracs, [outlier_lf, vert_funcs[kk]],
                         limits=limits) for kk in cells
                ]
            elif outlier_type == 'training':
                # Outlier weight varies from cell to cell following the
                # Gaussian outlier PDF rescaled over the fine grid, while the
                # outlier term itself is uniform
                outlier_lf = gauss(self.params['outlier_mean'],
                                   self.params['outlier_sigma']**2)
                full_pdf = np.exp(u.safe_log(outlier_lf.pdf(self.z_fine)))
                intermediate = np.dot(full_pdf,
                                      np.ones(self.n_tot) * self.dz_fine)
                full_pdf = full_pdf / intermediate[np.newaxis]
                fracs = np.array([full_pdf, np.ones(self.n_tot)]).T
                fracs = fracs * np.array([frac, 1. - frac])[np.newaxis, :]
                grid_funcs = [
                    gmix(fracs[kk], [uniform_lf, vert_funcs[kk]],
                         limits=limits) for kk in cells
                ]
            else:
                # Previously this fell through with `grid_funcs` unbound and
                # failed below with an uninformative UnboundLocalError
                raise ValueError(
                    "unrecognized catastrophic_outliers option " +
                    repr(outlier_type) +
                    "; expected '0', 'uniform', 'template' or 'training'")

        p_space = [
            multi_dist([hor_funcs[kk], grid_funcs[kk]]) for kk in cells
        ]

        return p_space
Exemplo n.º 13
0
    def create(self, truth, int_pr, N=d.n_gals, vb=False):
        """
        Builds a catalog of interim posterior probability distributions;
        intended to be split up into helper functions

        Parameters
        ----------
        truth: chippr.gmix object or chippr.gauss object or chippr.discrete
        object
            true redshift distribution object
        int_pr: chippr.gmix object or chippr.gauss object or chippr.discrete
        object
            interim prior distribution object
        N: int, optional
            log10 of the number of galaxies to draw (10**N samples are taken)
        vb: boolean, optional
            True to make the diagnostic plots along the way, False to suppress

        Returns
        -------
        self.cat: dict
            dictionary comprising catalog information
        """
        self.N = 10**N
        self.N_range = range(self.N)
        self.truth = truth

        self.proc_bins()

        # Probability space weighted by the truth, used to draw the sample
        components = self.make_probs()
        truth_amps = self.truth.evaluate(self.z_fine) * self.bin_difs_fine
        self.pspace_draw = gmix(truth_amps, components)
        if vb:
            plots.plot_prob_space(self.z_fine,
                                  self.pspace_draw,
                                  plot_loc=self.plot_dir,
                                  prepend=self.cat_name + 'draw_')

        # Draw the sample and store it as the true values
        self.samps = self.pspace_draw.sample(self.N)
        self.cat['true_vals'] = self.samps
        if vb:
            plots.plot_true_histogram(self.samps.T[0],
                                      n_bins=(self.n_coarse, self.n_tot),
                                      plot_loc=self.plot_dir,
                                      prepend=self.cat_name)

        # Probability space weighted by the interim prior, used to evaluate
        # the likelihoods
        self.int_pr = int_pr
        prior_fine = np.array([self.int_pr.pdf(self.z_fine)])
        self.pspace_eval = gmix(prior_fine, components)
        if vb:
            plots.plot_prob_space(self.z_fine,
                                  self.pspace_eval,
                                  plot_loc=self.plot_dir,
                                  prepend=self.cat_name + 'eval_')

        self.obs_lfs = self.evaluate_lfs(self.pspace_eval)

        # Bring both the evaluations and the interim prior to the coarse bins
        posteriors_coarse = self.coarsify(self.obs_lfs)
        prior_coarse = self.coarsify(prior_fine)

        if vb:
            plots.plot_mega_scatter(self.samps,
                                    self.obs_lfs,
                                    self.z_fine,
                                    self.bin_ends,
                                    truth=[self.z_fine, truth_amps],
                                    plot_loc=self.plot_dir,
                                    prepend=self.cat_name,
                                    int_pr=[self.z_fine, prior_fine[0]])

        self.cat['bin_ends'] = self.bin_ends
        self.cat['log_interim_prior'] = u.safe_log(prior_coarse[0])
        self.cat['log_interim_posteriors'] = u.safe_log(posteriors_coarse)

        return self.cat