def _lngrid_from_trace(self, trace, make_plot):
        """Return the max-normalized log-density of a 2-D trace on a fixed grid.

        A getdist density estimate is built from ``trace`` with parameter
        names ``x0``/``x1``. Their ranges are ``self.par_limits[1]`` and
        ``self.par_limits[2]``. The estimate is evaluated at every node of
        the grid spanned by ``self.get_x0_x1``. Non-positive values are
        floored at 1e-50, and the result is logged and shifted so its
        maximum is 0.

        Parameters
        ----------
        trace : samples handed to ``MCSamples``; presumably shaped
            (n_samples, 2) with columns (x0, x1) -- TODO confirm
        make_plot : if True, draw the log-density with ``plt.pcolormesh``
            and a colour range of [-5, 0]

        Returns
        -------
        ln_prob : array shaped like ``np.meshgrid(x_bins, y_bins)[0]``
            (i.e. (len(y_bins), len(x_bins)) for 1-D bins), with maximum 0
        """
        extents = {'x0': self.par_limits[1], 'x1': self.par_limits[2]}

        samps = MCSamples(samples=trace, names=['x0', 'x1'], ranges=extents)
        density = samps.get2DDensity('x0', 'x1')

        # Set up the grid on which to evaluate the density.
        # NOTE(review): accessed as an attribute/property; if get_x0_x1 is a
        # plain method it must be called here -- confirm.
        x_bins, y_bins = self.get_x0_x1

        xx, yy = np.meshgrid(x_bins, y_bins)
        pos = np.vstack([xx.ravel(), yy.ravel()]).T

        # Evaluate the density at every grid node
        prob = np.array([density.Prob(*ele) for ele in pos])
        # The estimate can come out non-positive; floor it so the log stays
        # finite (zero would otherwise become -inf).
        prob[prob <= 0] = 1e-50
        ln_prob = np.log(prob)
        ln_prob -= ln_prob.max()
        ln_prob = ln_prob.reshape(xx.shape)

        if make_plot:
            plt.pcolormesh(x_bins, y_bins, ln_prob)
            plt.colorbar()
            # ln_prob <= 0 after normalization; clim takes (vmin, vmax), so
            # show the top 5 log-units as [-5, 0].
            plt.clim(-5, 0)
        return ln_prob
def mcmcSkewer(bundleObj,
               logdef=3,
               binned=False,
               niter=2500,
               do_mcmc=True,
               return_sampler=False,
               evalgrid=True,
               in_axes=None,
               viz=False,
               VERBOSITY=False,
               seed=None,
               truths=(0.002, 3.8),
               nburn=500,
               nwalkers=100):
    """
    Script to fit a simple flux model on each rest-frame wavelength skewer

    Parameters:
    -----------
        bundleObj : sequence whose item 0 is an array with columns
            [z, f, ivar] and whose item 1 is the skewer index (used in
            messages and output file names)
        logdef : which model to use. 1, 2 and 3 select the model families
            chisq1/2/3, lnlike1/2/3, etc. 4 only estimates the continuum
            for fixed optical-depth parameters ``truths``.
        binned : if True, replace the data by statistics binned on ``binx``.
            The MCMC then uses the ``simpleln`` likelihood with the model-3
            parameter metadata.
        niter : number of production iterations of the MCMC (the first 40%
            of them are additionally discarded as burn-in)
        do_mcmc : flag whether to perform the MCMC
        return_sampler : return the raw production chain without flattening
        evalgrid : whether to evaluate the log-likelihood on the grid
            ``modPos`` and write it to ``gridlnlike_<index>.dat``
        in_axes : axes on which to draw the data and best fits (created when
            None and ``viz`` is True)
        viz : draw the data and fits, a triangle plot and the grid contours
        VERBOSITY : print extra information
        seed : seed for numpy's global random state and the sampler's
            random state; None draws fresh entropy on every call
        truths : used with logdef=4, best-fit values of tau0 and gamma
        nburn : number of burn-in steps run (and discarded) before the
            production run
        nwalkers : number of ensemble walkers

    Returns:
    --------
        logdef == 4 : the scipy-optimized continuum estimate (``opt_res['x']``)
        return_sampler : ``sampler.chain`` of the production run
        otherwise : None. Results are written to tg_est_<index>.dat,
            xx_est_<index>.dat and, with evalgrid, gridlnlike_<index>.dat.

    Raises:
    -------
        ValueError : if logdef is not one of 1, 2, 3, 4
    """
    # Validate up front: an unknown model would otherwise surface later as an
    # UnboundLocalError on ``nll``.
    if logdef not in (1, 2, 3, 4):
        raise ValueError(
            'logdef must be one of 1, 2, 3 or 4, got {!r}'.format(logdef))

    z, f, ivar = bundleObj[0].T

    # Keep only pixels with positive inverse variance and a finite flux
    ind = (ivar > 0) & np.isfinite(f)
    z, f, sigma = z[ind], f[ind], 1.0 / np.sqrt(ivar[ind])
    # -------------------------------------------------------------------------
    # continuum flux estimate given a value of (tau0, gamma)
    if logdef == 4:
        if VERBOSITY:
            print('Continuum estimates using optical depth parameters:',
                  truths)

        def chisq4(*args):
            # outer(tau0, gamma) presumably builds the objective to maximize;
            # negate it so Nelder-Mead can minimize.
            return -outer(*truths)(*args)

        opt_res = minimize(chisq4,
                           1.5,
                           args=(z, f, sigma),
                           method='Nelder-Mead')
        return opt_res['x']

    if VERBOSITY:
        print('Carrying analysis for skewer', bundleObj[1])

    # Select objective, parameter metadata and likelihood of the model
    if logdef == 1:
        nll, names, labels, guess = chisq1, names1, labels1, guess1
        ndim, kranges, lnlike = 4, kranges1, lnlike1

    elif logdef == 2:
        nll, names, labels, guess = chisq2, names2, labels2, guess2
        ndim, kranges, lnlike = 5, kranges2, lnlike2

    else:  # logdef == 3 (validated above)
        nll, names, labels, guess = chisq3, names3, labels3, guess3
        ndim, kranges, lnlike = 3, kranges3, lnlike3

    # Try to fit with scipy optimize routine
    opt_res = minimize(nll, guess, args=(z, f, sigma), method='Nelder-Mead')
    print('Scipy optimize results:')
    print('Success =', opt_res['success'], 'params =', opt_res['x'], '\n')

    if viz:
        if in_axes is None:
            _, in_axes = plt.subplots(1)
        in_axes.errorbar(z, f, sigma, fmt='o', color='gray', alpha=0.2)
        # NOTE(review): the curve assumes the first three parameters enter as
        # p0 * exp(-exp(p1) * (1 + z)**p2); confirm for logdef 1 and 2.
        in_axes.plot(
            zline, opt_res['x'][0] * np.exp(-np.exp(opt_res['x'][1]) *
                                            (1 + zline)**opt_res['x'][2]))

    if binned:
        # Replace the data by binned means and spreads
        mu = binned_statistic(z, f, bins=binx).statistic
        sig = binned_statistic(z, f, bins=binx, statistic=sig_func).statistic

        ixs = sig > 0
        z, f, sigma = centers[ixs], mu[ixs], sig[ixs]

        if viz:
            in_axes.errorbar(z, f, sigma, fmt='o', color='r')

        nll, names, labels, guess = lsq, names3, labels3, guess3
        ndim, kranges, lnlike = 3, kranges3, simpleln

    # --------------------------------------------------------------------------
    if do_mcmc:
        # Seed numpy's global RNG, which draws the walkers' starting points.
        # seed=None reseeds from fresh entropy, as before.
        np.random.seed(seed)

        p0 = [guess + 1e-4 * np.random.randn(ndim) for _ in range(nwalkers)]

        # configure the sampler
        sampler = emcee.EnsembleSampler(nwalkers,
                                        ndim,
                                        lnlike,
                                        args=(z, f, sigma))
        if seed is not None:
            # emcee keeps its own RandomState; sync it with the seeded global
            # state so the whole chain (not only p0) is reproducible.
            sampler.random_state = np.random.get_state()

        # burn-in; the final positions start the production run
        p0, __, __ = sampler.run_mcmc(p0, nburn)
        sampler.reset()

        # Production step
        sampler.run_mcmc(p0, niter)
        print("Burn-in and production completed \n")

        if return_sampler:
            return sampler.chain
        else:
            # pruning 40 percent of the samples as extra burn-in
            lInd = int(niter * 0.4)
            samps = sampler.chain[:, lInd:, :].reshape((-1, ndim))

            # central (median) estimate of each parameter
            CenVal = np.median(samps, axis=0)

            # Goodness of fit at the central estimate, assuming
            # lnlike = -chi^2 / 2:  chi^2_r = -2 ln L / (n - k) and
            # BIC = -2 ln L + k ln(n), with k = ndim free parameters.
            lnl_cen = lnlike(CenVal, z, f, sigma)
            print('CHISQ_R', -2 * lnl_cen / (len(z) - ndim))
            print('BIC:', -2 * lnl_cen + ndim * np.log(len(z)))

            def _summarize(arr):
                # Per column: (median, upper error, lower error) from the
                # 16th/50th/84th percentiles.
                p16, p50, p84 = np.percentile(arr, [16, 50, 84], axis=0)
                return [(mid, hi - mid, mid - lo)
                        for lo, mid, hi in zip(p16, p50, p84)]

            # Format : center, top error, bottom error
            tg_est = _summarize(samps)

            # Transform parameters 1: to the modified (x0, x1) basis
            xx = xfm(samps[:, 1:], shift, tilt, direction='up')
            xx_est = _summarize(xx)

            f_name2 = 'tg_est_' + str(bundleObj[1]) + '.dat'
            np.savetxt(f_name2, tg_est)
            f_name3 = 'xx_est_' + str(bundleObj[1]) + '.dat'
            np.savetxt(f_name3, xx_est)

            if viz:
                in_axes.plot(
                    zline, CenVal[0] * np.exp(-np.exp(CenVal[1]) *
                                              (1 + zline)**CenVal[2]), '-g')

            # instantiate a getdist object
            MC = MCSamples(samples=samps,
                           names=names,
                           labels=labels,
                           ranges=kranges)

            # MODIFY THIS TO BE PRETTIER
            if viz:
                g = plots.getSubplotPlotter()
                g.triangle_plot(MC)

            # Evaluate the pdf on a rotated grid for better estimation
            if evalgrid:
                print('Evaluating on the grid specified \n')
                pdist = MC.get2DDensity('t0', 'gamma')

                # Evaluate the density on the grid points in modPos
                pgrid = np.array([pdist.Prob(*ele) for ele in modPos])
                # Floor non-positive densities so the log stays finite
                # (zero would otherwise become -inf)
                pgrid[pgrid <= 0] = 1e-50

                # Convert to a log-likelihood whose maximum is 0
                logP = np.log(pgrid)
                logP -= logP.max()
                logP = logP.reshape(x0.shape)

                # Visualize the contour in modified space per skewer
                if viz:
                    _, ax2 = plt.subplots(1)
                    # levels presumably mark the 1- and 2-sigma regions
                    ax2.contour(x0,
                                x1,
                                cts(logP),
                                levels=[0.683, 0.955],
                                colors='k')
                    ax2.axvline(xx_est[0][0] + xx_est[0][1])
                    ax2.axvline(xx_est[0][0] - xx_est[0][2])
                    ax2.axhline(xx_est[1][0] + xx_est[1][1])
                    ax2.axhline(xx_est[1][0] - xx_est[1][2])
                    ax2.set_xlabel(r'$x_0$')
                    ax2.set_ylabel(r'$x_1$')
                    plt.show()

                # fileName1: the log-probability evaluated in the tilted grid
                f_name1 = 'gridlnlike_' + str(bundleObj[1]) + '.dat'
                np.savetxt(f_name1, logP)