Beispiel #1
0
    def __init__(self,
                 likelihood_evaluator,
                 nwalkers,
                 transd=False,
                 pool=None,
                 likelihood_call=None,
                 update_interval=None):
        """Set up a kombine sampler for ``likelihood_evaluator``.

        Parameters
        ----------
        likelihood_evaluator : object
            Evaluator whose ``waveform_generator.variable_args`` fixes the
            number of sampled dimensions. It is also used as the posterior
            callable when ``likelihood_call`` is not given.
        nwalkers : int
            Number of walkers in the ensemble.
        transd : bool, optional
            Forwarded unchanged to ``kombine.Sampler``.
        pool : object, optional
            Pool forwarded to ``kombine.Sampler``. Its ``count`` attribute is
            used as the number of processes. When ``None``, one process is
            requested.
        likelihood_call : callable, optional
            Callable handed to kombine instead of ``likelihood_evaluator``.
        update_interval : optional
            Stored on the instance as ``update_interval``.

        Raises
        ------
        ImportError
            If kombine is not installed.
        """
        try:
            import kombine
        except ImportError:
            raise ImportError("kombine is not installed.")

        if likelihood_call is None:
            likelihood_call = likelihood_evaluator

        # construct sampler for use in KombineSampler
        ndim = len(likelihood_evaluator.waveform_generator.variable_args)
        # pool defaults to None: fall back to a single process instead of
        # dereferencing None.count
        count = 1 if pool is None else pool.count
        sampler = kombine.Sampler(nwalkers,
                                  ndim,
                                  likelihood_call,
                                  transd=transd,
                                  pool=pool,
                                  processes=count)
        # initialize
        super(KombineSampler, self).__init__(sampler, likelihood_evaluator)
        self._nwalkers = nwalkers
        self.update_interval = update_interval
Beispiel #2
0
    def __init__(self,
                 likelihood_evaluator,
                 nwalkers=0,
                 ndim=0,
                 transd=False,
                 processes=None):
        """Attach a freshly built ``kombine.Sampler`` to this object.

        The kombine sampler is stored on ``self._sampler`` before the parent
        initializer is invoked with ``likelihood_evaluator``.

        Raises
        ------
        ImportError
            If kombine cannot be imported.
        """
        try:
            import kombine
        except ImportError:
            raise ImportError("kombine is not installed.")

        sampler_options = dict(transd=transd, processes=processes)
        self._sampler = kombine.Sampler(nwalkers, ndim, likelihood_evaluator,
                                        **sampler_options)

        # hand over to the base class only once the sampler exists
        super(KombineSampler, self).__init__(likelihood_evaluator)
Beispiel #3
0
    def __init__(self, model, nwalkers, transd=False,
                 pool=None, model_call=None,
                 update_interval=None):
        """Build a kombine ensemble sampler for ``model``.

        Parameters
        ----------
        model : object
            Model whose ``variable_params`` sets the dimensionality. It is
            also the posterior callable unless ``model_call`` is supplied.
        nwalkers : int
            Ensemble size.
        transd : bool, optional
            Passed through to ``kombine.Sampler``.
        pool : object, optional
            Worker pool. ``pool.count`` becomes the process count, or 1 when
            no pool is given.
        model_call : callable, optional
            Alternative callable evaluated by kombine.
        update_interval : optional
            Saved as ``self.update_interval``.

        Raises
        ------
        ImportError
            If kombine is not installed.
        """
        try:
            import kombine
        except ImportError:
            raise ImportError("kombine is not installed.")

        lnpost = model if model_call is None else model_call

        n_dims = len(model.variable_params)
        n_procs = pool.count if pool is not None else 1
        sampler = kombine.Sampler(nwalkers, n_dims, lnpost,
                                  transd=transd, pool=pool,
                                  processes=n_procs)

        super(KombineSampler, self).__init__(sampler, model)
        self._nwalkers = nwalkers
        self.update_interval = update_interval
Beispiel #4
0
    def __init__(self,
                 likelihood_evaluator,
                 nwalkers,
                 transd=False,
                 processes=None,
                 min_burn_in=None):
        """Create a kombine sampler over the evaluator's variable arguments.

        Parameters
        ----------
        likelihood_evaluator : object
            Posterior callable. Its ``waveform_generator.variable_args``
            determines the number of dimensions.
        nwalkers : int
            Ensemble size.
        transd : bool, optional
            Passed through to ``kombine.Sampler``.
        processes : optional
            Passed through to ``kombine.Sampler``.
        min_burn_in : optional
            Forwarded to the parent initializer.

        Raises
        ------
        ImportError
            If kombine is not installed.
        """
        try:
            import kombine
        except ImportError:
            raise ImportError("kombine is not installed.")

        variable_args = likelihood_evaluator.waveform_generator.variable_args
        sampler = kombine.Sampler(nwalkers, len(variable_args),
                                  likelihood_evaluator,
                                  transd=transd, processes=processes)

        super(KombineSampler, self).__init__(
            sampler, likelihood_evaluator, min_burn_in=min_burn_in)
        self._nwalkers = nwalkers
Beispiel #5
0
def kdes(self, p0=None, nsteps=3000, nwalks=None, tune=None, seed=None, linear=None, resume=False, verbose=False, debug=False):
    """Run the kombine ensemble sampler and store chain, samples and mode.

    Parameters
    ----------
    p0 : array, optional
        Initial walker positions. If not given, the last state of
        ``self.fdict['kdes_chain']`` is used when ``resume`` is true.
        Otherwise positions are drawn via ``get_par(self, 'best', ...)``.
    nsteps : int
        Total number of steps, split between burn-in and production below.
    nwalks : int, optional
        Number of walkers (defaults to 120).
    tune : optional
        Only checked for ``None`` (which resets ``self.tune``). The number of
        trailing steps used for thinning is recomputed from ``nsteps``.
    seed : int, optional
        Seed for ``np.random``. Defaults to ``self.fdict['seed']``.
    linear : bool, optional
        Passed to the log-probability. Defaults to
        ``self.filter.name == 'KalmanFilter'``.
    resume : bool
        Restart from the last state of the stored chain.
    verbose : bool
        When true, warnings are left alone and the progress bar is hidden.
    debug : bool
        Build the sampler without the process pool.

    Results are stored on ``self`` (``kdes_chain``, ``kdes_sample``,
    ``sampler``) and in ``self.fdict``. Returns ``None``.
    """
    import warnings

    import pathos
    import kombine
    from grgrlib.patches import kombine_run_mcmc

    # monkey-patch kombine; presumably this adds support for the ``pbar``
    # keyword used below
    kombine.Sampler.run_mcmc = kombine_run_mcmc

    if not hasattr(self, 'ndim'):
        # if it seems to be missing, lets do it.
        # but without guarantee...
        self.prep_estim(load_R=True)

    if seed is None:
        seed = self.fdict['seed']

    np.random.seed(seed)

    if tune is None:
        self.tune = None

    if linear is None:
        linear = self.filter.name == 'KalmanFilter'

    if nwalks is None:
        nwalks = 120

    if 'description' in self.fdict:
        self.description = self.fdict['description']

    if not use_cloudpickle:
        # globals are *evil*
        global lprob_global
    else:
        import cloudpickle as cpickle
        lprob_dump = cpickle.dumps(self.lprob)
        lprob_global = cpickle.loads(lprob_dump)

    def lprob(par):
        return lprob_global(par, linear, verbose)

    if self.pool:
        self.pool.clear()

    if debug:
        sampler = kombine.Sampler(nwalks, self.ndim, lprob)
    else:
        sampler = kombine.Sampler(
            nwalks, self.ndim, lprob, pool=self.pool)

        # NOTE(review): the pool is closed before any sampling happens; this
        # relies on the pool still serving the sampler's map calls - confirm.
        if self.pool:
            self.pool.close()

    if p0 is None:
        if resume:
            # should work, but not tested
            p0 = self.fdict['kdes_chain'][-1]
        else:
            p0 = get_par(self, 'best', asdict=False, full=True,
                         nsample=nwalks, verbose=verbose)

    if not verbose:
        # numpy's former ``np.warnings`` alias was the stdlib module
        warnings.filterwarnings('ignore')

    # Created unconditionally (hidden when verbose) because it is passed to
    # burnin/run_mcmc and closed below.
    pbar = tqdm.tqdm(total=nsteps, unit='sample(s)', dynamic_ncols=True,
                     disable=verbose)

    if nsteps < 500:
        nsteps_burnin = nsteps
        nsteps_mcmc = 0
    elif nsteps < 1000:
        nsteps_burnin = 500
        nsteps_mcmc = nsteps - nsteps_burnin
    else:
        nsteps_mcmc = 500
        nsteps_burnin = nsteps - nsteps_mcmc

    tune = max(500, nsteps_burnin)

    p, post, q = sampler.burnin(
        p0, max_steps=nsteps_burnin, pbar=pbar, verbose=verbose)

    if nsteps_mcmc:
        p, post, q = sampler.run_mcmc(nsteps_mcmc, pbar=pbar)

    # per-walker thinning interval, ceil(2 / mean acceptance - 1), computed
    # over the last ``tune`` steps
    acls = np.ceil(
        2/np.mean(sampler.acceptance[-tune:], axis=0) - 1).astype(int)
    # each walker contributes rows of length ndim (not a hard-coded 2)
    samples = np.concatenate(
        [sampler.chain[-tune::acl, c].reshape(-1, self.ndim)
         for c, acl in enumerate(acls)])

    kdes_chain = sampler.chain
    kdes_sample = samples.reshape(1, -1, self.ndim)

    self.kdes_chain = kdes_chain
    self.kdes_sample = kdes_sample
    self.fdict['tune'] = tune
    self.fdict['kdes_chain'] = kdes_chain
    self.fdict['kdes_sample'] = kdes_sample

    pbar.close()

    if not verbose:
        warnings.filterwarnings('default')

    # NOTE(review): get_log_prob/get_chain follow the emcee API; confirm the
    # (patched) kombine sampler provides them.
    log_probs = sampler.get_log_prob()[self.tune:]
    chain = sampler.get_chain()[self.tune:]
    chain = chain.reshape(-1, chain.shape[-1])

    arg_max = log_probs.argmax()
    mode_f = log_probs.flat[arg_max]
    mode_x = chain[arg_max]

    self.fdict['kombine_mode_x'] = mode_x
    self.fdict['kombine_mode_f'] = mode_f

    if 'mode_f' in self.fdict and mode_f < self.fdict['mode_f']:
        print('[kombine:]'.ljust(15, ' ') + " New mode of %s is below old mode of %s. Rejecting..." %
              (mode_f, self.fdict['mode_f']))
    else:
        self.fdict['mode_x'] = mode_x
        self.fdict['mode_f'] = mode_f

    self.sampler = sampler
    self.fdict['datetime'] = str(datetime.now())

    return
Beispiel #6
0
def main(apogee_id,
         index,
         n_walkers,
         n_steps,
         sampler_name,
         n_burnin=128,
         mpi=False,
         seed=42,
         overwrite=False):
    """Fit an orbit to one APOGEE target by MCMC and save chain and plots.

    Parameters
    ----------
    apogee_id : str or None
        Target identifier. When ``None``, it is looked up from the Troup
        catalog via ``index``.
    index : int or None
        Row of the Troup catalog to use when ``apogee_id`` is not given.
    n_walkers : int
        Number of walkers.
    n_steps : int
        Number of production steps.
    sampler_name : {'emcee', 'kombine'}
        Sampler to use.
    n_burnin : int
        Burn-in steps for emcee. For kombine, any positive value enables a
        call to ``sampler.burnin(p0)``, made without a step limit.
    mpi : bool
        Use an MPI pool.
    seed : int
        Seed for ``np.random``, used when drawing initial walker positions.
    overwrite : bool
        Re-run and replace results already stored for this target.

    Raises
    ------
    ValueError
        For an unknown ``sampler_name`` or when an initial position has a
        non-finite log-probability.

    The process exits with status 0 when done. It also exits early when the
    target has already been modeled and ``overwrite`` is false.
    """
    # MPI setup
    pool = get_pool(mpi=mpi, loadbalance=True)
    # need load-balancing - see: https://groups.google.com/forum/#!msg/mpi4py/OJG5eZ2f-Pg/EnhN06Ozg2oJ

    np.random.seed(seed)

    # read in Troup catalog
    _troup = np.genfromtxt(TROUP_DATA_PATH,
                           delimiter=",",
                           names=True,
                           dtype=None)

    if index is not None and apogee_id is None:
        apogee_id = _troup['APOGEE_ID'].astype(str)[index]

    OUTPUT_FILENAME = join(OUTPUT_PATH, "troup-{}.hdf5".format(sampler_name))
    if exists(OUTPUT_FILENAME) and not overwrite:
        with h5py.File(OUTPUT_FILENAME, 'r') as f:
            already_modeled = apogee_id in f
        if already_modeled:
            logger.info("{} has already been modeled - use '--overwrite' "
                        "to re-run MCMC for this target.".format(apogee_id))
            pool.close()
            sys.exit(0)

    # load data files -- Troup catalog and full APOGEE allVisit file
    troup = tbl.Table(_troup[_troup['APOGEE_ID'].astype(str) == apogee_id])
    _allvisit = fits.getdata(ALLVISIT_DATA_PATH, 1)
    target = tbl.Table(
        _allvisit[_allvisit['APOGEE_ID'].astype(str) == apogee_id])

    # read data and orbit parameters and produce initial guess for MCMC
    logger.debug("Reading data from Troup catalog and allVisit files...")
    data = allVisit_to_rvdata(target)
    troup_orbit = troup_to_init_orbit(troup, data)
    n_dim = 7  # HACK: magic number

    # first figure is initial guess
    plot_init_orbit(troup_orbit, data, apogee_id)

    # create model object to evaluate prior, likelihood, posterior
    model = OrbitModel(data=data, orbit=troup_orbit.copy())

    # sample initial conditions for walkers
    logger.debug("Generating initial conditions for MCMC walkers...")
    p0 = emcee.utils.sample_ball(model.get_par_vec(),
                                 1E-3 * model.get_par_vec(),
                                 size=n_walkers)

    # special treatment for ln_P
    p0[:, 0] = np.random.normal(np.log(model.orbit._P), 0.5, size=p0.shape[0])

    # special treatment for s
    p0[:, 6] = np.abs(np.random.normal(0, 1E-3, size=p0.shape[0]) * u.km /
                      u.s).decompose(usys).value

    if sampler_name == 'emcee':
        sampler = emcee.EnsembleSampler(n_walkers,
                                        dim=n_dim,
                                        lnpostfn=model,
                                        pool=pool)

    elif sampler_name == 'kombine':
        # TODO: add option for Prior-sampled initial conditions, don't assume uniform for kombine
        p0 = np.zeros((n_walkers, n_dim))

        p0[:, 0] = np.random.uniform(1., 8., n_walkers)

        _asini = np.random.uniform(-1., 3., n_walkers)
        _phi0 = np.random.uniform(0, 2 * np.pi, n_walkers)
        p0[:, 1] = _asini * np.cos(_phi0)
        p0[:, 2] = _asini * np.sin(_phi0)

        _ecc = np.random.uniform(0, 1, n_walkers)
        _omega = np.random.uniform(0, 2 * np.pi, n_walkers)
        p0[:, 3] = np.sqrt(_ecc) * np.cos(_omega)
        p0[:, 4] = np.sqrt(_ecc) * np.sin(_omega)

        p0[:, 5] = (np.random.normal(0., 75., n_walkers) * u.km /
                    u.s).decompose(usys).value

        p0[:, 6] = (np.exp(np.random.uniform(-8, 0., n_walkers)) * u.km /
                    u.s).decompose(usys).value

        sampler = kombine.Sampler(n_walkers,
                                  ndim=n_dim,
                                  lnpostfn=model,
                                  pool=pool)

    else:
        raise ValueError("Invalid sampler name '{}'".format(sampler_name))

    # make sure all initial conditions return finite probabilities
    # (explicit check rather than assert, which is stripped under -O)
    for pp in p0:
        if not np.isfinite(model(pp)):
            raise ValueError("Initial walker position {} has a non-finite "
                             "log-probability.".format(pp))

    # burn-in phase
    if n_burnin > 0:
        logger.debug(
            "Burning in the MCMC sampler for {} steps...".format(n_burnin))

        _t1 = time.time()
        if sampler_name == 'kombine':
            sampler.burnin(p0)
        else:
            pos, _, _ = sampler.run_mcmc(p0, N=n_burnin)
            sampler.reset()
        logger.debug("done with burn-in after {} seconds.".format(time.time() -
                                                                  _t1))

    else:
        pos = p0

    # run the sampler
    _t1 = time.time()

    logger.info("Running MCMC sampler for {} steps...".format(n_steps))
    if sampler_name == 'kombine':
        if n_burnin > 0:
            # continue from where burn-in left the walkers
            sampler.run_mcmc(n_steps)
        else:
            # no burn-in ran, so there is no previous state to continue
            # from: pass the initial positions explicitly
            sampler.run_mcmc(n_steps, p0)
    else:
        sampler.run_mcmc(pos, N=n_steps)

    pool.close()
    logger.info("done sampling after {} seconds.".format(time.time() - _t1))

    if sampler_name == 'kombine':
        # HACK: kombine uses different axes order
        chain = np.swapaxes(sampler.chain, 0, 1)
    else:
        chain = sampler.chain

    # output the chain and metadata to HDF5 file
    with h5py.File(OUTPUT_FILENAME,
                   'a') as f:  # read/write if exists, create otherwise
        if apogee_id in f and overwrite:
            del f[apogee_id]

        elif apogee_id in f and not overwrite:
            # should not get here!!
            raise RuntimeError("How did I get here???")

        g = f.create_group(apogee_id)

        g.create_dataset('p0', data=p0)
        g.create_dataset('chain', data=chain)

        # metadata
        g.attrs['n_walkers'] = n_walkers
        g.attrs['n_steps'] = n_steps
        g.attrs['n_burnin'] = n_burnin

    # plot orbits computed from the samples
    logger.debug("Plotting the MCMC samples...")

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    fig = model.plot_rv_samples(chain[:, -1], ax=ax)
    _ = model.data.plot(ax=ax)
    fig.tight_layout()
    fig.savefig(join(PLOT_PATH, "{}-1-rv-curves.png".format(apogee_id)),
                dpi=256)

    # make a corner plot
    flatchain = np.vstack(chain[:, -256:])
    plot_pars = model.vec_to_plot_pars(flatchain)
    troup_vals = [
        np.log(troup_orbit.P.to(u.day).value), troup_orbit.mf.value,
        troup_orbit.ecc,
        troup_orbit.omega.to(u.degree).value, troup_orbit.t0.mjd,
        -troup_orbit.v0.to(u.km / u.s).value, 0.
    ]
    fig = corner.corner(plot_pars, labels=model.plot_labels, truths=troup_vals)
    fig.tight_layout()
    fig.savefig(join(PLOT_PATH, "{}-2-corner.png".format(apogee_id)), dpi=256)
    logger.debug("done!")

    # make MCMC diagnostic plots as well (e.g., acceptance fraction, chain traces)
    plot_mcmc_diagnostics(sampler, p0, model, sampler_name, apogee_id)

    sys.exit(0)
Beispiel #7
0
        """
        return -3.5 * np.log(self._interpolant(X[0], X[1], grid=False))

    def lnpost(self, X):
        """Return the log-posterior at ``X``: log-prior plus log-likelihood."""
        log_prior = self.lnprior(X)
        return log_prior + self.lnlike(X)

    def __call__(self, X):
        """Evaluate the log-posterior at ``X`` by delegating to ``lnpost``."""
        posterior = self.lnpost
        return posterior(X)


lnprob = Posterior('../docs/_static/kombine.png')

# Initially distribute the ensemble across the prior
nwalkers = 1000
ndim = 2
sampler = kombine.Sampler(nwalkers, ndim, lnprob)
p = lnprob.prior_draw(nwalkers)

# Sample for a bit
p, prob, q = sampler.run_mcmc(200, p)

if triangle is None:
    print("Get triangle.py for some awesome corner plots!")
    print("https://github.com/dfm/triangle.py")

else:
    # draw the corner plot of the positions returned by run_mcmc once and
    # save it
    fig = triangle.corner(p)
    fig.savefig("triangle.png")

if prism is None:
Beispiel #8
0
 def _initialise_sampler(self):
     """Build the kombine sampler from ``sampler_init_kwargs``, then set up the chain file."""
     import kombine
     init_kwargs = self.sampler_init_kwargs
     self._sampler = kombine.Sampler(**init_kwargs)
     self._init_chain_file()
        self.ndim = self.cov[0].shape

    def logpdf(self, x):
        """Return the log-density at ``x`` of a zero-mean normal with covariance ``self.cov``."""
        zero_mean = np.zeros(self.ndim)
        return mvn.logpdf(x, mean=zero_mean, cov=self.cov)

    def __call__(self, x):
        """Evaluate the target log-density at ``x`` via ``logpdf``."""
        density = self.logpdf
        return density(x)


A = np.random.rand(ndim, ndim)
cov = A + A.T + ndim * np.eye(ndim)
lnpdf = Target(cov)

# NOTE(review): the timed region includes burn-in, while ``total`` counts
# only the production steps - confirm this is the intended rate.
start_time = time.perf_counter()
nwalkers = 500
sampler = kombine.Sampler(nwalkers, ndim, lnpdf, processes=1)
p0 = np.random.uniform(-10, 10, size=(nwalkers, ndim))
p, post, q = sampler.burnin(p0)
Nsteps = 1000
total = Nsteps * nwalkers * ndim
p, post, q = sampler.run_mcmc(Nsteps)

# keep full precision for the rate; the %.2f format below rounds the display
time_elapsed = time.perf_counter() - start_time
print("KOMBINE SAMPLER TEST - 2D Gaussian")
print("Time Elapsed: %.2f" % time_elapsed)
print("Average Step/sec: %.3f\n" % float(total / time_elapsed))

fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(15, 3))
ax1.plot(sampler.acceptance_fraction,
         'k',
         alpha=.5,
Beispiel #10
0
def sample(logpost,
           p0,
           ftol=None,
           view='moll',
           lookback=5,
           epoch_starts=None,
           epoch_duration=np.inf,
           nwalkers=1024,
           nskip=10,
           nsteps=100,
           **kwargs):
    """Run kombine in chunks of ``nskip`` steps with live diagnostic plots.

    After every chunk, the best walker is plotted against the light-curve
    data (figure 1). The recent parameter and log-probability history is
    shown in figure 2, and the map from ``logpost.hpmap`` is drawn with
    ``projector`` (figure 3). Sampling continues while the summed relative
    change of the best walker's (log-likelihood, log-prior) is negative,
    zero or above ``ftol``, or until KeyboardInterrupt.

    Parameters
    ----------
    logpost : object
        Log-posterior callable. It must expose ``times``, ``reflectance``,
        ``sigma_reflectance``, ``loglikelihood``, ``log_prior``,
        ``lightcurve``, ``error_scale``, ``to_params``, ``hpmap`` and
        ``dtype``.
    p0 : array_like
        Initial walker positions. ``len(p0[0])`` sets the dimensionality.
    ftol : float, optional
        Convergence tolerance. Defaults to ``0.1 * len(p0)``.
    view : str
        Passed to ``projector`` as ``view``.
    lookback : int
        Number of recent best-fit states shown.
    epoch_starts : sequence, optional
        Start times of the plotted epochs. Defaults to
        ``[logpost.times[0]]``.
    epoch_duration : float
        Length of each plotted epoch.
    nwalkers : int
        Ensemble size.
    nskip : int
        Steps per chunk between plot updates.
    nsteps : int
        Unused; kept for backward compatibility.
    **kwargs
        Forwarded to ``projector``.

    Returns
    -------
    The walker positions after the last chunk. If interrupted with Ctrl-C,
    returns the most recent best walker instead. The figures are closed in
    both cases.
    """
    import kombine

    ndim = len(p0[0])
    # best walker after each chunk, seeded with the first initial walker
    pbests = [p0[0]]
    p = p0

    if ftol is None:
        ftol = .1 * len(p0)

    sampler = kombine.Sampler(nwalkers, ndim, logpost)

    epoch_starts = [logpost.times[0]] if epoch_starts is None else epoch_starts

    def cb(x, probs):
        """Redraw all diagnostics; return the relative change in log-probability."""
        pbest = x[np.argmax(probs)]
        clear_output(wait=True)

        # store only the best walker so pbests[-1] is a single point
        pbests.append(pbest)

        fig1 = plt.figure(num=1)
        axs = fig1.get_axes()
        # first row: data and model light curves, second row: residuals
        half = len(axs) // 2
        ax1, ax2 = axs[:half], axs[half:]

        fig2 = plt.figure(num=2)
        ax3, ax4 = fig2.get_axes()

        fig3 = plt.figure(num=3)

        for a1, a2 in zip(ax1, ax2):
            a1.cla()
            a2.cla()

        ax3.cla()
        ax4.cla()

        components = (logpost.loglikelihood(pbest), logpost.log_prior(pbest))
        history.append(pbest)
        lnprobs.append(components)

        for ax, epoch_start in zip(ax1, epoch_starts):
            sel = (logpost.times >= epoch_start) & (logpost.times - epoch_start
                                                    < epoch_duration)
            ax.errorbar(logpost.times[sel],
                        logpost.reflectance[sel],
                        logpost.sigma_reflectance[sel],
                        color='k',
                        lw=1.5,
                        fmt='o',
                        markersize=0,
                        capthick=0)

        for i in range(0, lookback):
            try:
                p = history[-(i + 1)]
                lc = logpost.lightcurve(p)
                for a1, a2, epoch_start in zip(ax1, ax2, epoch_starts):
                    sel = (logpost.times >= epoch_start) & (
                        logpost.times - epoch_start < epoch_duration)
                    # newest state in orange, older ones fading in blue
                    color = 'orange' if i == 0 else 'b'
                    a1.plot(logpost.times[sel],
                            lc[sel],
                            color=color,
                            alpha=1 - i * 1. / lookback)
                    a2.plot(logpost.times[sel],
                            (logpost.reflectance - lc)[sel] /
                            (logpost.error_scale(p) *
                             logpost.sigma_reflectance)[sel],
                            color=color,
                            alpha=1 - i * 1. / lookback)
                    a1.set_xlim(logpost.times[sel].min(),
                                logpost.times[sel].max())
                    a2.set_xlim(logpost.times[sel].min(),
                                logpost.times[sel].max())
            except IndexError:
                # fewer than ``lookback`` states recorded so far
                pass
        for a1, a2 in zip(ax1, ax2):
            l, h = a2.get_xlim()
            a2.plot((l, h), (0, 0), color='k', ls='--', alpha=0.5)
            a1.set_xlabel('time')
            a2.set_xlabel('time')
        ax1[0].set_ylabel('reflectance')
        ax2[0].set_ylabel('standardized residual')

        low, high = len(history) - min(lookback, len(history)), len(history)
        xs = np.arange(low, high)
        for param in params:
            ax3.plot(xs,
                     [logpost.to_params(p)[param] for p in history[low:high]],
                     label=param)
        ax3.legend(loc='lower left', frameon=False)
        lines = ax4.plot(
            xs[1:],
            np.diff(lnprobs[-min(lookback, len(lnprobs)):], axis=0) /
            lnprobs[-1])
        ax4.legend(lines, ['likelihood', 'prior'],
                   loc='lower left',
                   frameon=False)
        ax3.set_xlabel('steps')
        ax4.set_xlabel('steps')
        ax3.set_ylabel('param')
        ax4.set_ylabel(r'$\Delta \log$ PDF')

        display(fig1)
        display(fig2)

        fig3.clear()
        projector(logpost.hpmap(pbest), view=this_view, fig=3, **kwargs)
        display(fig3)

        display(components)
        try:
            prob_diff = np.sum(np.diff(lnprobs[-2:], axis=0) / lnprobs[-1])
        except IndexError:
            prob_diff = np.inf
        return prob_diff

    history = []
    lnprobs = []
    params = logpost.dtype.names

    this_view = view

    fig1, _ = plt.subplots(2, len(epoch_starts), num=1, figsize=(16, 8))
    fig2, _ = plt.subplots(1, 2, num=2, figsize=(16, 4))
    fig3 = plt.figure(num=3)

    try:
        delta_p = np.inf
        while delta_p < 0 or delta_p > ftol or delta_p == 0.:
            p, prob, _ = sampler.run_mcmc(nskip, p)
            delta_p = cb(p, prob)
    except KeyboardInterrupt:
        return pbests[-1]
    finally:
        # release the figures whether sampling converged or was interrupted
        plt.close(fig1)
        plt.close(fig2)
        plt.close(fig3)
    return p
Beispiel #11
0
from scipy.stats import multivariate_normal


class Model(object):
    """Multivariate normal log-density, callable as a sampler posterior."""

    def __init__(self, mean, cov):
        # promote a scalar mean to shape (1,)
        self.mean = np.atleast_1d(mean)
        self.cov = np.array(cov)
        # dimensionality is read off the covariance matrix
        self.ndim = self.cov.shape[0]

    def lnposterior(self, x):
        """Return the Gaussian log-density evaluated at ``x``."""
        mu, sigma = self.mean, self.cov
        return multivariate_normal.logpdf(x, mean=mu, cov=sigma)

    def __call__(self, x):
        """Alias for :meth:`lnposterior`."""
        density = self.lnposterior
        return density(x)


ndim = 3
A = np.random.rand(ndim, ndim)
mean = np.zeros(ndim)
# elementwise product (symmetric) plus a dominant diagonal term
cov = A * A.T + ndim * np.eye(ndim)

# ND Gaussian target
model = Model(mean, cov)

nwalkers = 500
sampler = kombine.Sampler(nwalkers=nwalkers, ndim=ndim, lnpostfn=model,
                          mpi=True)

# start uniformly in [-10, 10), burn in, then take 100 production steps
p0 = np.random.uniform(-10, 10, size=(nwalkers, ndim))
p, post, q = sampler.burnin(p0)
p, post, q = sampler.run_mcmc(100)
Beispiel #12
0
        return self.lnposterior(x)


ndim = 3
A = np.random.rand(ndim, ndim)
mean = np.zeros(ndim)
cov = A * A.T + ndim * np.eye(ndim)

# create an ND Gaussian model
model = Model(mean, cov)

# define an MPI pool
pool = MPIPool()

# Only the master process continues; the others block in pool.wait()
# (presumably until the master closes the pool) and then exit.
if not pool.is_master():
    pool.wait()
    sys.exit(0)

nwalkers = 500
try:
    sampler = kombine.Sampler(nwalkers, ndim, model, pool=pool)

    p0 = np.random.uniform(-10, 10, size=(nwalkers, ndim))
    p, post, q = sampler.burnin(p0)
    p, post, q = sampler.run_mcmc(100)
finally:
    # close the MPI pool even if sampling raises, so the waiting worker
    # processes are released instead of hanging
    pool.close()

sys.exit(0)
Beispiel #13
0
def bvp_mcmc_single(config_params, chain_filename_ncomp=None):
    """
    Run MCMC and save the chain based on parameters defined
    in the config file.

    Parameters
    -----------
    config_params: obj
        Parameter object defined by the config file. Its attributes are
        used to initialize the walkers and to configure the run
        (``mcmc_sampler``, ``nwalkers``, ``nsteps``, ``nthreads`` and
        ``chain_fname``).
    chain_filename_ncomp: str
        Output filename without extension; '.npy' is appended.
        Defaults to ``config_params.chain_fname``.

    Returns
    -----------
    None. Two files are written:

    ``<chain_filename_ncomp>.npy``
        The chain, with the same axis order for both samplers (emcee's
        chain is swapped to kombine's step-first layout). Use np.load to
        read it back.
    ``<chain_filename_ncomp>_GR.dat``
        The Gelman-Rubin indicator evaluated every 5% of ``nsteps``.
    """
    from bayesvp.likelihood import Posterior

    if chain_filename_ncomp is None:
        chain_filename_ncomp = config_params.chain_fname

    # define the MCMC parameters.
    p0 = _create_walkers_init(config_params)
    ndim = np.shape(p0)[1]

    # Define the natural log of the posterior
    lnprob = Posterior(config_params)

    sampler_name = config_params.mcmc_sampler.lower()
    if sampler_name == 'emcee':
        import emcee
        sampler = emcee.EnsembleSampler(config_params.nwalkers,
                                        ndim,
                                        lnprob,
                                        threads=config_params.nthreads)
        sampler.run_mcmc(p0, config_params.nsteps)
        # swap the first two axes to match kombine's step-first layout
        chain = np.swapaxes(sampler.chain, 0, 1)

    elif sampler_name == 'kombine':
        import kombine
        sampler = kombine.Sampler(config_params.nwalkers,
                                  ndim,
                                  lnprob,
                                  processes=config_params.nthreads)

        # First do a rough burn in based on acceptance rate.
        sampler.burnin(p0)
        sampler.run_mcmc(config_params.nsteps)
        chain = sampler.chain

    else:
        sys.exit('Error! No MCMC sampler selected.\nExiting program...')

    np.save(chain_filename_ncomp + '.npy', chain)

    # Compute the Gelman-Rubin indicator on growing step prefixes of the
    # saved (step-first) chain. With fewer than 20 steps the 5% interval is
    # zero and no points are computed.
    dnsteps = int(config_params.nsteps * 0.05)
    if dnsteps > 0:
        n_steps = list(range(dnsteps, config_params.nsteps, dnsteps))
    else:
        n_steps = []
    Rgrs = np.array([gr_indicator(chain[:n, :, :]) for n in n_steps])
    n_steps = np.array(n_steps)

    np.savetxt(chain_filename_ncomp + '_GR.dat',
               np.c_[n_steps, Rgrs],
               fmt='%.5f',
               header='col1=steps\tcoln=gr_indicator')

    return
Beispiel #14
0
def _kombine_model_setup(name):
    """Return (label, log-posterior, initial-position spec) for model ``name``.

    The spec holds one ``(draw, a, b)`` tuple per parameter. Column ``i`` of
    the initial walker positions is ``draw(a, b, size=nwalkers)``.

    Raises
    ------
    ValueError
        If ``name`` is not 'nfw', 'hernq' or 'bfe'.
    """
    if name == 'nfw':
        return "NFW", ln_posterior_nfw, [
            (np.random.normal, 1., 0.05),   # alpha
            (np.random.normal, 0.2, 0.02),  # v_c
            (np.random.normal, 3., 0.1),    # log_r_s
        ]
    if name == "hernq":
        return "Hernquist", ln_posterior_hernq, [
            (np.random.normal, 1., 0.05),   # alpha
            (np.random.normal, 26., 0.15),  # log_m
            (np.random.normal, 20., 0.5),   # c
        ]
    if name == "bfe":
        return "BFE", ln_posterior_bfe, [
            (np.random.normal, 1., 0.05),   # alpha
            (np.random.normal, 26., 0.15),  # log_m
            (np.random.normal, 20., 0.5),   # c
            (np.random.uniform, 0.9, 1.1),  # c1
            (np.random.uniform, 0., 0.1),   # c2
            (np.random.uniform, 0., 0.05),  # c3
            (np.random.uniform, 0., 0.02),  # c4
        ]
    raise ValueError("Unknown model name '{}'; expected 'nfw', 'hernq' "
                     "or 'bfe'.".format(name))


def main(name, nstars, nburn, nwalk, mpi=False, save=False,
         output_path="/vega/astro/users/amp2217/projects/streams/output/michigan_hack"):
    """Sample potential parameters for a simulated stream with kombine.

    Parameters
    ----------
    name : {'nfw', 'hernq', 'bfe'}
        Potential model whose log-posterior is sampled.
    nstars : int
        Number of distinct stream stars drawn from the simulation.
    nburn, nwalk : int
        Forwarded to ``sample_dat_ish`` as ``nburn`` / ``nwalk``.
    mpi : bool
        Use an MPI pool.
    save : bool
        Save ``sampler.chain`` and ``sampler.lnprobability`` as ``.npy``.
    output_path : str
        Directory for the saved ``<name>_chain.npy`` / ``<name>_lnprob.npy``.

    Raises
    ------
    ValueError
        For an unknown ``name`` or if ``nstars`` exceeds the number of
        stream stars in the snapshot.

    Exits the process with status 0 when done.
    """
    # validate the model choice before doing any expensive work
    label, lnpostfn, init_spec = _kombine_model_setup(name)

    pool = get_pool(mpi=mpi)
    try:
        np.random.seed(42)

        # parameters
        true_sat_mass = 2.5E5
        dt = -0.1

        # read in the SCF simulation data
        s = scf.SCFReader(os.path.join(scfpath, "simulations/runs/spherical/"))

        tbl = s.last_snap(units=galactic)
        total_time = tbl.meta['time']
        nsteps = abs(int(total_time / dt))

        stream_tbl = tbl[(tbl["tub"] != 0)]
        prog_w = np.median(scf.tbl_to_w(tbl[(tbl["tub"] == 0)]), axis=0)

        # pluck out nstars distinct stars, redrawing until there are no
        # duplicates; this can only succeed if enough stars exist
        if nstars > len(stream_tbl):
            raise ValueError("Requested {} stars but the stream only has {}."
                             .format(nstars, len(stream_tbl)))
        ixs = []
        while len(ixs) < nstars:
            ixs = np.unique(np.random.randint(len(stream_tbl), size=nstars))
        data_w = scf.tbl_to_w(stream_tbl[ixs])

        prog_E = true_potential.total_energy(prog_w[:3], prog_w[3:])
        dE = (true_potential.total_energy(data_w[:, :3], data_w[:, 3:]) -
              prog_E) / prog_E
        # -1 for stars with dE > 0, +1 otherwise
        betas = -2 * (dE > 0).astype(int) + 1.

        # ----------------------------
        # kombine
        print("Firing up sampler for {}".format(label))
        ndim = len(init_spec)
        nwalkers = 128 * ndim
        sampler = kombine.Sampler(nwalkers=nwalkers,
                                  ndim=ndim,
                                  lnpostfn=lnpostfn,
                                  args=(dt, nsteps, prog_w, data_w, betas,
                                        true_sat_mass),
                                  pool=pool)

        # draw each column in parameter order (keeps the seeded RNG stream)
        p0 = np.zeros((nwalkers, ndim))
        for col, (draw, a, b) in enumerate(init_spec):
            p0[:, col] = draw(a, b, size=nwalkers)

        sampler = sample_dat_ish(sampler, p0, nburn=nburn, nwalk=nwalk)

        if save:
            np.save(os.path.join(output_path, "{}_chain.npy".format(name)),
                    sampler.chain)
            np.save(os.path.join(output_path, "{}_lnprob.npy".format(name)),
                    sampler.lnprobability)
    finally:
        # always release the pool, also when loading or sampling fails
        pool.close()

    sys.exit(0)