Example #1
def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, burnGR=1.03, maxGR=1.01,
         minTz=1000, minsteps=1000, thin=1, serial=False):
    """Run MCMC
    Run MCMC chains using the emcee EnsembleSampler
    Args:
        post (radvel.posterior): radvel posterior object
        nwalkers (int): (optional) number of MCMC walkers
        nrun (int): (optional) number of steps to take
        ensembles (int): (optional) number of ensembles to run. Will be run
            in parallel on separate CPUs
        checkinterval (int): (optional) check MCMC convergence statistics every
            `checkinterval` steps
        burnGR (float): (optional) Maximum G-R statistic to stop burn-in period
        maxGR (float): (optional) Maximum G-R statistic for chains to be deemed well-mixed and halt the MCMC run
        minTz (int): (optional) Minimum Tz to consider well-mixed
        minsteps (int): (optional) Minimum number of steps per walker before convergence tests are performed
        thin (int): (optional) save one sample every N steps (default=1, save every sample)
        serial (bool): set to true if MCMC should be run in serial
    Returns:
        DataFrame: DataFrame containing the MCMC samples
    """

    # check if one or more likelihoods are GPs
    if isinstance(post.likelihood, radvel.likelihood.CompositeLikelihood):
        check_gp = [like for like in post.likelihood.like_list if isinstance(like, radvel.likelihood.GPLikelihood)]
    else:
        check_gp = isinstance(post.likelihood, radvel.likelihood.GPLikelihood)  

    np_info = np.__config__.blas_opt_info
    if 'extra_link_args' in np_info.keys() \
       and check_gp \
       and ('-Wl,Accelerate' in np_info['extra_link_args']) \
       and not serial:
        print("WARNING: Parallel processing with Gaussian Processes will not work with your current"
              " numpy installation. See radvel.readthedocs.io/en/latest/OSX-multiprocessing.html"
              " for more details. Running in serial with {} ensembles.".format(ensembles))
        serial = True

    statevars.ensembles = ensembles
    statevars.nwalkers = nwalkers
    statevars.checkinterval = checkinterval
    
    nrun = int(nrun)
        
    # Get an initial array value
    pi = post.get_vary_params()
    statevars.ndim = pi.size

    if nwalkers < 2*statevars.ndim:
        print("WARNING: Number of walkers is less than 2 times the number "
              "of free parameters. Adjusting number of walkers to {}".format(2*statevars.ndim))
        statevars.nwalkers = 2*statevars.ndim

    # set up perturbation size
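    # Walkers start in a tight ball around the current parameter values:
    # periods are perturbed at the ~1e-5 fractional level, tc by 0.1
    # (days, in radvel's time convention), and other parameters by 10%.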
    pscales = []
    for par in post.list_vary_params():
        val = post.params[par].value
        if post.params[par].mcmcscale is None:
            if par.startswith('per'):
                pscale = np.abs(val * 1e-5*np.log10(val))
            elif par.startswith('logper'):
                pscale = np.abs(1e-5 * val)
            elif par.startswith('tc'):
                pscale = 0.1
            else:
                pscale = np.abs(0.10 * val)
            post.params[par].mcmcscale = pscale
        else:
            pscale = post.params[par].mcmcscale
        pscales.append(pscale)
    pscales = np.array(pscales)

    statevars.samplers = []
    statevars.initial_positions = []
    for e in range(ensembles):
        pi = post.get_vary_params()
        p0 = np.vstack([pi]*statevars.nwalkers)
        p0 += [np.random.rand(statevars.ndim)*pscales for i in range(statevars.nwalkers)]
        statevars.initial_positions.append(p0)
        statevars.samplers.append(emcee.EnsembleSampler( 
            statevars.nwalkers, statevars.ndim, post.logprob_array, threads=1))

    num_run = int(np.round(nrun / checkinterval))
    statevars.totsteps = nrun*statevars.nwalkers*statevars.ensembles
    statevars.mixcount = 0
    statevars.ismixed = 0
    statevars.burn_complete = False
    statevars.nburn = 0
    statevars.ncomplete = statevars.nburn
    statevars.pcomplete = 0
    statevars.rate = 0
    statevars.ar = 0
    statevars.mintz = -1
    statevars.maxgr = np.inf
    statevars.t0 = time.time()

    for r in range(num_run):
        t1 = time.time()
        mcmc_input_array = []
        for i, sampler in enumerate(statevars.samplers):
            if sampler.flatlnprobability.shape[0] == 0:
                p1 = statevars.initial_positions[i]
            else:
                p1 = None
            mcmc_input = (sampler, p1, checkinterval)
            mcmc_input_array.append(mcmc_input)

        if serial:
            statevars.samplers = []
            for i in range(ensembles):
                result = _domcmc(mcmc_input_array[i])
                statevars.samplers.append(result)
        else:
            pool = mp.Pool(statevars.ensembles)
            statevars.samplers = pool.map(_domcmc, mcmc_input_array)
            pool.close()  # terminates worker processes once all work is done
            pool.join()   # waits for all processes to finish before proceeding

        t2 = time.time()
        statevars.interval = t2 - t1

        convergence_check(statevars.samplers, maxGR=maxGR, minTz=minTz, minsteps=minsteps)

        # Burn-in complete after maximum G-R statistic first reaches burnGR
        # reset samplers
        if not statevars.burn_complete and statevars.maxgr <= burnGR:
            for i, sampler in enumerate(statevars.samplers):
                statevars.initial_positions[i] = sampler._last_run_mcmc_result[0]
                sampler.reset()
                statevars.samplers[i] = sampler
            msg = (
                "\nDiscarding burn-in now that the chains are marginally "
                "well-mixed\n"
            )
            print(msg)
            statevars.nburn = statevars.ncomplete
            statevars.burn_complete = True

        if statevars.mixcount >= 5:
            tf = time.time()
            tdiff = tf - statevars.t0
            tdiff,units = utils.time_print(tdiff)
            msg = (
                "\nChains are well-mixed after {:d} steps! MCMC completed in "
                "{:3.1f} {:s}"
            ).format(statevars.ncomplete, tdiff, units)
            print(msg)
            break

    print("\n")
    if statevars.ismixed and statevars.mixcount < 5: 
        msg = (
            "MCMC: WARNING: chains did not pass 5 consecutive convergence "
            "tests. They may be marginally well=mixed."
        )
        print(msg)
    elif not statevars.ismixed: 
        msg = (
            "MCMC: WARNING: chains did not pass convergence tests. They are "
            "likely not well-mixed."
        )
        print(msg)
        
    df = pd.DataFrame(
        statevars.tchains.reshape(
            statevars.ndim, statevars.tchains.shape[1] * statevars.tchains.shape[2]).transpose(),
        columns=post.list_vary_params())
    df['lnprobability'] = np.hstack(statevars.lnprob)

    df = df.iloc[::thin]

    return df
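
For context, a minimal call sketch, assuming a posterior built through the usual radvel workflow; the setup below is illustrative and elided (real data arrays and priors are required for an actual fit):

import radvel

# Illustrative one-planet setup; t, vel, errvel stand in for real RV data arrays.
params = radvel.Parameters(1, basis='per tc secosw sesinw logk')
mod = radvel.RVModel(params)
like = radvel.likelihood.RVLikelihood(mod, t, vel, errvel)
post = radvel.posterior.Posterior(like)

chains = mcmc(post, nwalkers=100, nrun=20000, ensembles=4)
print(chains['lnprobability'].max())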
Example #2
# NOTE: this snippet references burnGR below without defining it; a keyword
# argument (default matching Example #1) is added here so the code runs.
def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, burnGR=1.03):
    """Run MCMC
    Run MCMC chains using the emcee EnsembleSampler
    Args:
        post (radvel.posterior): radvel posterior object
        nwalkers (int): number of MCMC walkers
        nrun (int): number of steps to take
        ensembles (int): number of ensembles to run. Will be run
            in parallel on separate CPUs
        checkinterval (int): check MCMC convergence statistics every 
            `checkinterval` steps
    Returns:
        DataFrame: DataFrame containing the MCMC samples
    """

    # server = pp.Server(ncpus=ensembles)

    # statevars.server = server
    statevars.ensembles = ensembles
    statevars.nwalkers = nwalkers
    statevars.checkinterval = checkinterval

    nrun = int(nrun)

    # Get an initial array value
    pi = post.get_vary_params()
    statevars.ndim = pi.size

    if nwalkers < 2 * statevars.ndim:
        print("WARNING: Number of walkers is less than 2 times the number "
              "of free parameters. Adjusting number of walkers to {}".format(2 * statevars.ndim))
        statevars.nwalkers = 2 * statevars.ndim

    # set up perturbation size
    pscales = []
    for par in post.list_vary_params():
        val = post.params[par].value
        if post.params[par].mcmcscale is None:
            if par.startswith('per'):
                pscale = np.abs(val * 1e-5 * np.log10(val))
            #    pscale_per = pscale
            elif par.startswith('logper'):
                pscale = np.abs(1e-5 * val)
            #    pscale_per = pscale
            elif par.startswith('tc'):
                pscale = 0.1
            else:
                pscale = np.abs(0.10 * val)
            post.params[par].mcmcscale = pscale
        else:
            pscale = post.params[par].mcmcscale
        pscales.append(pscale)
    pscales = np.array(pscales)

    statevars.samplers = []
    statevars.initial_positions = []
    for e in range(ensembles):
        lcopy = copy.deepcopy(post)
        pi = lcopy.get_vary_params()
        p0 = np.vstack([pi] * statevars.nwalkers)
        p0 += [
            np.random.rand(statevars.ndim) * pscales
            for i in range(statevars.nwalkers)
        ]
        statevars.initial_positions.append(p0)
        statevars.samplers.append(
            emcee.EnsembleSampler(statevars.nwalkers,
                                  statevars.ndim,
                                  lcopy.logprob_array,
                                  threads=1))

    num_run = int(np.round(nrun / checkinterval))
    statevars.totsteps = nrun * statevars.nwalkers * statevars.ensembles
    statevars.mixcount = 0
    statevars.ismixed = 0
    statevars.burn_complete = False
    statevars.nburn = 0
    statevars.ncomplete = statevars.nburn
    statevars.pcomplete = 0
    statevars.rate = 0
    statevars.ar = 0
    statevars.mintz = -1
    statevars.maxgr = np.inf
    statevars.t0 = time.time()

    for r in range(num_run):
        t1 = time.time()
        mcmc_input_array = []
        for i, sampler in enumerate(statevars.samplers):
            if sampler.flatlnprobability.shape[0] == 0:
                p1 = statevars.initial_positions[i]
            else:
                p1 = None
            mcmc_input = (sampler, p1, checkinterval)
            mcmc_input_array.append(mcmc_input)

        pool = Pool(statevars.ensembles)
        statevars.samplers = pool.map(_domcmc, mcmc_input_array)
        pool.close()  #terminates worker processes once all work is done
        pool.join()  #waits for all processes to finish before proceeding

        t2 = time.time()
        statevars.interval = t2 - t1

        convergence_check(statevars.samplers)

        # Burn-in complete after maximum G-R statistic first reaches burnGR
        # reset samplers
        if not statevars.burn_complete and statevars.maxgr <= burnGR:
            for i, sampler in enumerate(statevars.samplers):
                statevars.initial_positions[i] = sampler._last_run_mcmc_result[0]
                sampler.reset()
                statevars.samplers[i] = sampler
            msg = ("\nDiscarding burn-in now that the chains are marginally "
                   "well-mixed\n")
            print(msg)
            statevars.nburn = statevars.ncomplete
            statevars.burn_complete = True

        if statevars.mixcount >= 5:
            tf = time.time()
            tdiff = tf - statevars.t0
            tdiff, units = utils.time_print(tdiff)
            msg = (
                "\nChains are well-mixed after {:d} steps! MCMC completed in "
                "{:3.1f} {:s}").format(statevars.ncomplete, tdiff, units)
            print(msg)
            break

    print("\n")
    if statevars.ismixed and statevars.mixcount < 5:
        msg = ("MCMC: WARNING: chains did not pass 5 consecutive convergence "
               "tests. They may be marginally well=mixed.")
        print(msg)
    elif not statevars.ismixed:
        msg = (
            "MCMC: WARNING: chains did not pass convergence tests. They are "
            "likely not well-mixed.")
        print(msg)

    df = pd.DataFrame(statevars.tchains.reshape(
        statevars.ndim,
        statevars.tchains.shape[1] * statevars.tchains.shape[2]).transpose(),
                      columns=post.list_vary_params())
    df['lnprobability'] = np.hstack(statevars.lnprob)

    return df
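
The `_domcmc` worker that Examples #1, #2, and #4 hand to the multiprocessing pool is defined elsewhere in radvel and not shown on this page. Judging from the (sampler, initial_position, checkinterval) tuples built above, and from the `_crunch` helper in Example #5, a minimal sketch might look like this (the None handling relies on emcee 2.x resuming from the sampler's last position):

def _domcmc(input_tuple):
    # Sketch inferred from the call sites; the real radvel helper may differ.
    sampler, ipos, checkinterval = input_tuple
    # In emcee 2.x, run_mcmc(None, N) continues from the previous position,
    # which matches the p1 = None branch above.
    sampler.run_mcmc(ipos, checkinterval)
    return sampler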
Example #3
# NOTE: this snippet references burnGR below without defining it; a keyword
# argument (default matching Example #1) is added here so the code runs.
def mcmc(likelihood, nwalkers=50, nrun=10000, threads=1, checkinterval=50, burnGR=1.03):
    """Run MCMC

    Run MCMC chains using the emcee EnsembleSampler

    Args:
        likelihood (radvel.likelihood): radvel likelihood object
        nwalkers (int): number of MCMC walkers
        nrun (int): number of steps to take
        threads (int): number of CPU threads to utilize
        checkinterval (int): check MCMC convergence statistics every `checkinterval` steps

    Returns:
        DataFrame: DataFrame containing the MCMC samples

    """
    # Get an initial array value
    p0 = likelihood.get_vary_params()
    ndim = p0.size
    p0 = np.vstack([p0] * nwalkers)
    p0 += [np.random.rand(ndim) * 0.03 for i in range(nwalkers)]
    sampler = emcee.EnsembleSampler(nwalkers,
                                    ndim,
                                    likelihood.logprob_array,
                                    threads=threads)

    pos = p0
    num_run = int(np.round(nrun / checkinterval))
    totsteps = nrun * nwalkers
    mixcount = 0
    burn_complete = False
    t0 = time.time()
    for r in range(num_run):
        t1 = time.time()
        pos, prob, state = sampler.run_mcmc(pos, checkinterval)
        t2 = time.time()

        rate = (checkinterval * nwalkers) / (t2 - t1)
        ncomplete = sampler.flatlnprobability.shape[0]
        pcomplete = ncomplete / float(totsteps) * 100
        ar = sampler.acceptance_fraction.mean() * 100.

        tchains = sampler.chain.transpose()

        (ismixed, gr, tz) = gelman_rubin(tchains)
        mintz = min(tz)
        maxgr = max(gr)
        if ismixed:
            mixcount += 1
        else:
            mixcount = 0

        # Burn-in complete after maximum G-R statistic first reaches burnGR
        # reset sampler
        if not burn_complete and maxgr <= burnGR:
            sampler.reset()
            print("\nDiscarding burn-in now that the chains are marginally well-mixed\n")
            burn_complete = True

        if mixcount >= 5:
            tf = time.time()
            tdiff = tf - t0
            tdiff, units = utils.time_print(tdiff)
            print "\nChains are well-mixed after %d steps! MCMC completed in %3.1f %s" % (
                ncomplete, tdiff, units)
            break
        else:
            sys.stdout.write(
                "%d/%d (%3.1f%%) steps complete; Running %.2f steps/s; Mean acceptance rate = %3.1f%%; Min Tz = %.1f; Max G-R = %4.2f      \r"
                % (ncomplete, totsteps, pcomplete, rate, ar, mintz, maxgr))
            sys.stdout.flush()

    print "\n"
    if ismixed and mixcount < 5:
        print "MCMC: WARNING: chains did not pass 5 consecutive convergence tests. They may be marginally well=mixed."
    elif not ismixed:
        print "MCMC: WARNING: chains did not pass convergence tests. They are likely not well-mixed."

    df = pd.DataFrame(sampler.flatchain, columns=likelihood.list_vary_params())
    df['lnprobability'] = sampler.flatlnprobability
    return df
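
The `gelman_rubin` helper called above lives elsewhere in radvel. Conceptually it compares between-walker and within-walker variances to form the potential scale reduction factor for each parameter, plus the Tz statistic counting effectively independent draws. A rough sketch under those assumptions (the Tz proxy here is a simplification, not radvel's exact formula):

import numpy as np

def gelman_rubin_sketch(tchains, maxGR=1.01, minTz=1000):
    # Illustrative Gelman-Rubin check, not the exact radvel implementation.
    # tchains has shape (ndim, nsteps, nwalkers), matching
    # sampler.chain.transpose() in the example above.
    ndim, nsteps, nwalkers = tchains.shape
    gr = np.empty(ndim)
    tz = np.empty(ndim)
    for k in range(ndim):
        chains = tchains[k]                                 # (nsteps, nwalkers)
        within = chains.var(axis=0, ddof=1).mean()          # W: mean within-walker variance
        between = nsteps * chains.mean(axis=0).var(ddof=1)  # B: between-walker variance
        varplus = (nsteps - 1) / nsteps * within + between / nsteps
        gr[k] = np.sqrt(varplus / within)                   # potential scale reduction factor
        # Crude effective-sample-size proxy standing in for radvel's Tz statistic.
        tz[k] = nwalkers * nsteps * min(varplus / between, 1.0)
    ismixed = (gr < maxGR).all() and (tz > minTz).all()
    return ismixed, gr, tz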
Example #4
def mcmc(post,
         nwalkers=50,
         nrun=10000,
         ensembles=8,
         checkinterval=50,
         minAfactor=40,
         maxArchange=.03,
         burnAfactor=25,
         burnGR=1.03,
         maxGR=1.01,
         minTz=1000,
         minsteps=1000,
         minpercent=5,
         thin=1,
         serial=False,
         save=False,
         savename=None,
         proceed=False,
         proceedname=None):
    """Run MCMC
    Run MCMC chains using the emcee EnsembleSampler
    Args:
        post (radvel.posterior): radvel posterior object
        nwalkers (int): (optional) number of MCMC walkers
        nrun (int): (optional) number of steps to take
        ensembles (int): (optional) number of ensembles to run. Will be run
            in parallel on separate CPUs
        checkinterval (int): (optional) check MCMC convergence statistics every
            `checkinterval` steps
        minAfactor (float): Minimum autocorrelation time factor to deem chains as well-mixed and halt the MCMC run
        maxArchange (float): Maximum relative change in autocorrelation time to deem chains as well-mixed
        burnAfactor (float): Minimum autocorrelation time factor to stop burn-in period. Burn-in ends once burnGR
            or burnAfactor is reached.
        burnGR (float): (optional) Maximum G-R statistic to stop burn-in period. Burn-in ends once burnGR or
            burnAfactor is reached.
        maxGR (float): (optional) Maximum G-R statistic for chains to be deemed well-mixed and halt the MCMC run
        minTz (int): (optional) Minimum Tz to consider well-mixed
        minsteps (int): Minimum number of steps per walker before convergence tests are performed. Convergence checks
            will start after the minsteps threshold or the minpercent threshold has been hit.
        minpercent (float): Minimum percentage of total steps before convergence tests are performed. Convergence checks
            will start after the minsteps threshold or the minpercent threshold has been hit.
        thin (int): (optional) save one sample every N steps (default=1, save every sample)
        serial (bool): set to true if MCMC should be run in serial
        save (bool): set to true to save MCMC chains that can be continued in a future run
        savename (string): location of h5py file where MCMC chains will be saved for future use
        proceed (bool): set to true to continue a previously saved run
        proceedname (string): location of h5py file containing chains from a previous MCMC run
    Returns:
        DataFrame: DataFrame containing the MCMC samples
    """

    statevars.reset()

    try:
        if save and savename is None:
            raise ValueError('save set to true but no savename provided')

        if save:
            h5f = h5py.File(savename, 'a')

        if proceed:
            if proceedname is None:
                raise ValueError(
                    'proceed set to true but no proceedname provided')
            else:
                h5p = h5py.File(proceedname, 'r')
                msg = 'Loading chains and run information from previous MCMC'
                print(msg)
            statevars.prechains = []
            statevars.prelog_probs = []
            statevars.preaccepted = []
            statevars.preburned = h5p['burned'][0]
            statevars.minafactor = h5p['crit'][0]
            statevars.maxarchange = h5p['crit'][1]
            statevars.mintz = h5p['crit'][2]
            statevars.maxgr = h5p['crit'][3]
            statevars.autosamples = list(h5p['autosample'])
            statevars.automin = list(h5p['automin'])
            statevars.automean = list(h5p['automean'])
            statevars.automax = list(h5p['automax'])
            for i in range(0, int((len(h5p.keys()) - 6) / 3)):
                str_chain = str(i) + '_chain'
                str_log_prob = str(i) + '_log_prob'
                str_accepted = str(i) + '_accepted'
                statevars.prechains.append(h5p[str_chain])
                statevars.prelog_probs.append(h5p[str_log_prob])
                statevars.preaccepted.append(h5p[str_accepted])

        # check if one or more likelihoods are GPs
        if isinstance(post.likelihood, radvel.likelihood.CompositeLikelihood):
            check_gp = [
                like for like in post.likelihood.like_list
                if isinstance(like, radvel.likelihood.GPLikelihood)
            ]
        else:
            check_gp = isinstance(post.likelihood,
                                  radvel.likelihood.GPLikelihood)

        np_info = np.__config__.blas_opt_info
        if 'extra_link_args' in np_info.keys() \
           and check_gp \
           and ('-Wl,Accelerate' in np_info['extra_link_args']) \
           and not serial:
            print("WARNING: Parallel processing with Gaussian Processes will not work with your current"
                  " numpy installation. See radvel.readthedocs.io/en/latest/OSX-multiprocessing.html"
                  " for more details. Running in serial with {} ensembles.".format(ensembles))
            serial = True

        statevars.ensembles = ensembles
        statevars.nwalkers = nwalkers
        statevars.checkinterval = checkinterval - 1

        nrun = int(nrun)

        # Get an initial array value
        pi = post.get_vary_params()
        statevars.ndim = pi.size

        if nwalkers < 2 * statevars.ndim:
            print("WARNING: Number of walkers is less than 2 times the number of free parameters. "
                  "Adjusting number of walkers to {}".format(2 * statevars.ndim))
            statevars.nwalkers = 2 * statevars.ndim

        if proceed:
            if len(h5p.keys()) != (3 * statevars.ensembles + 6) or h5p['0_chain'].shape[2] != statevars.ndim \
               or h5p['0_chain'].shape[1] != statevars.nwalkers:
                raise ValueError(
                    'nensembles, nwalkers, and the number of ' +
                    'parameters must be equal to those from previous run.')

        # set up perturbation size

        pscales = []
        names = post.name_vary_params()
        for i, par in enumerate(post.vary_params):
            val = post.vector.vector[par][0]
            if post.vector.vector[par][2] == 0:
                if names[i].startswith('per'):
                    pscale = np.abs(val * 1e-5 * np.log10(val))
                elif names[i].startswith('logper'):
                    pscale = np.abs(1e-5 * val)
                elif names[i].startswith('tc'):
                    pscale = 0.1
                elif val == 0:
                    pscale = .00001
                else:
                    pscale = np.abs(0.10 * val)
                post.vector.vector[par][2] = pscale
            else:
                pscale = post.vector.vector[par][2]
            pscales.append(pscale)
        pscales = np.array(pscales)

        statevars.samplers = []
        statevars.samples = []
        statevars.initial_positions = []
        for e in range(ensembles):
            pi = post.get_vary_params()
            p0 = np.vstack([pi] * statevars.nwalkers)
            p0 += [
                np.random.rand(statevars.ndim) * pscales
                for i in range(statevars.nwalkers)
            ]
            if not proceed:
                statevars.initial_positions.append(p0)
            else:
                statevars.initial_positions.append(
                    statevars.prechains[e][-1, :, :])
            statevars.samplers.append(
                emcee.EnsembleSampler(statevars.nwalkers,
                                      statevars.ndim,
                                      post.logprob_array,
                                      threads=1))

        if proceed:
            for i, sampler in enumerate(statevars.samplers):
                sampler.backend.grow(statevars.prechains[i].shape[0], None)
                sampler.backend.chain = statevars.prechains[i]
                sampler.backend.log_prob = statevars.prelog_probs[i]
                sampler.backend.accepted = statevars.preaccepted[i]
                sampler.backend.iteration = statevars.prechains[i].shape[0]

        num_run = int(np.round(nrun / (checkinterval - 1)))
        statevars.totsteps = nrun * statevars.nwalkers * statevars.ensembles
        statevars.mixcount = 0
        statevars.ismixed = 0
        if proceed and statevars.preburned != 0:
            statevars.burn_complete = True
            statevars.nburn = statevars.preburned
        else:
            statevars.burn_complete = False
            statevars.nburn = 0
        statevars.ncomplete = statevars.nburn
        statevars.pcomplete = 0
        statevars.rate = 0
        statevars.ar = 0
        statevars.minAfactor = -1
        statevars.maxArchange = np.inf
        statevars.mintz = -1
        statevars.maxgr = np.inf
        statevars.t0 = time.time()

        for r in range(num_run):
            t1 = time.time()
            mcmc_input_array = []
            for i, sampler in enumerate(statevars.samplers):
                if sampler.iteration <= 1 or statevars.proceed_started == 0:
                    p1 = statevars.initial_positions[i]
                    statevars.proceed_started = 1
                else:
                    p1 = sampler.get_last_sample()
                mcmc_input = (sampler, p1, (checkinterval - 1))
                mcmc_input_array.append(mcmc_input)

            if serial:
                statevars.samplers = []
                for i in range(ensembles):
                    result = _domcmc(mcmc_input_array[i])
                    statevars.samplers.append(result)
            else:
                pool = mp.Pool(statevars.ensembles)
                statevars.samplers = pool.map(_domcmc, mcmc_input_array)
                pool.close()  # terminates worker processes once all work is done
                pool.join()   # waits for all processes to finish before proceeding

            t2 = time.time()
            statevars.interval = t2 - t1

            convergence_check(minAfactor=minAfactor,
                              maxArchange=maxArchange,
                              maxGR=maxGR,
                              minTz=minTz,
                              minsteps=minsteps,
                              minpercent=minpercent)

            if save:
                for i, sampler in enumerate(statevars.samplers):
                    str_chain = str(i) + '_chain'
                    str_log_prob = str(i) + '_log_prob'
                    str_accepted = str(i) + '_accepted'
                    if str_chain in h5f.keys():
                        del h5f[str_chain]
                    if str_log_prob in h5f.keys():
                        del h5f[str_log_prob]
                    if str_accepted in h5f.keys():
                        del h5f[str_accepted]
                    if 'crit' in h5f.keys():
                        del h5f['crit']
                    if 'autosample' in h5f.keys():
                        del h5f['autosample']
                    if 'automin' in h5f.keys():
                        del h5f['automin']
                    if 'automean' in h5f.keys():
                        del h5f['automean']
                    if 'automax' in h5f.keys():
                        del h5f['automax']
                    if 'burned' in h5f.keys():
                        del h5f['burned']
                    h5f.create_dataset(str_chain, data=sampler.get_chain())
                    h5f.create_dataset(str_log_prob,
                                       data=sampler.get_log_prob())
                    h5f.create_dataset(str_accepted,
                                       data=sampler.backend.accepted)
                    h5f.create_dataset('crit',
                                       data=[
                                           statevars.minafactor,
                                           statevars.maxarchange,
                                           statevars.mintz, statevars.maxgr
                                       ])
                    h5f.create_dataset('autosample',
                                       data=statevars.autosamples)
                    h5f.create_dataset('automin', data=statevars.automin)
                    h5f.create_dataset('automean', data=statevars.automean)
                    h5f.create_dataset('automax', data=statevars.automax)
                    if statevars.burn_complete:
                        h5f.create_dataset('burned', data=[statevars.nburn])
                    else:
                        h5f.create_dataset('burned', data=[0])

            # Burn-in complete after maximum G-R statistic first reaches burnGR or minAfactor reaches burnAfactor
            # reset samplers
            if not statevars.burn_complete and (
                    statevars.maxgr <= burnGR
                    or burnAfactor <= statevars.minafactor):
                for i, sampler in enumerate(statevars.samplers):
                    statevars.initial_positions[i] = sampler.get_last_sample()
                    sampler.reset()
                    statevars.samplers[i] = sampler
                msg = (
                    "\nDiscarding burn-in now that the chains are marginally "
                    "well-mixed\n")
                print(msg)
                statevars.nburn = statevars.ncomplete
                statevars.burn_complete = True

            if statevars.mixcount >= 2:
                tf = time.time()
                tdiff = tf - statevars.t0
                tdiff, units = utils.time_print(tdiff)
                msg = (
                    "\nChains are well-mixed after {:d} steps! MCMC completed in "
                    "{:3.1f} {:s}").format(statevars.ncomplete, tdiff, units)
                _closescr()
                print(msg)
                break

        print("\n")
        if statevars.ismixed and statevars.mixcount < 2:
            msg = (
                "MCMC: WARNING: chains did not pass 2 consecutive convergence "
                "tests. They may be marginally well=mixed.")
            _closescr()
            print(msg)
        elif not statevars.ismixed:
            msg = (
                "MCMC: WARNING: chains did not pass convergence tests. They are "
                "likely not well-mixed.")
            _closescr()
            print(msg)

        preshaped_chain = np.dstack(statevars.chains)
        df = pd.DataFrame(preshaped_chain.reshape(
            preshaped_chain.shape[0],
            preshaped_chain.shape[1] * preshaped_chain.shape[2]).transpose(),
                          columns=post.name_vary_params())
        preshaped_ln = np.hstack(statevars.lnprob)
        df['lnprobability'] = preshaped_ln.reshape(preshaped_chain.shape[1] *
                                                   preshaped_chain.shape[2])
        df = df.iloc[::thin]

        statevars.factor = [minAfactor] * len(statevars.autosamples)

        return df

    except KeyboardInterrupt:
        curses.endwin()
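
The save/proceed machinery above lets a long run be checkpointed and resumed. A hypothetical two-session workflow (the file name is illustrative):

# Session 1: run and checkpoint the chains to HDF5.
chains = mcmc(post, nrun=20000, ensembles=8,
              save=True, savename='rv_chains.h5')

# Session 2, e.g. after an interruption: resume from the saved state.
# ensembles, nwalkers, and the number of free parameters must match the
# previous run, as enforced by the shape check in the function above.
chains = mcmc(post, nrun=20000, ensembles=8,
              proceed=True, proceedname='rv_chains.h5',
              save=True, savename='rv_chains.h5')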
Example #5
# NOTE: this snippet references burnGR below without defining it; a keyword
# argument (default matching Example #1) is added here so the code runs.
def mcmc(likelihood, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, burnGR=1.03):
    """Run MCMC

    Run MCMC chains using the emcee EnsembleSampler

    Args:
        likelihood (radvel.likelihood): radvel likelihood object
        nwalkers (int): number of MCMC walkers
        nrun (int): number of steps to take
        ensembles (int): number of ensembles to run. Will be run
            in parallel on separate CPUs
        checkinterval (int): check MCMC convergence statistics every 
            `checkinterval` steps

    Returns:
        DataFrame: DataFrame containing the MCMC samples

    """
    def _crunch(sampler, ipos, checkinterval):
        sampler.run_mcmc(ipos, checkinterval)
        return sampler

    server = pp.Server(ncpus=ensembles)
    # pool = Pool(processes=1)

    statevars.server = server
    statevars.ensembles = ensembles
    statevars.nwalkers = nwalkers
    statevars.checkinterval = checkinterval
    
    nrun = int(nrun)
        
    # Get an initial array value
    pi = likelihood.get_vary_params()
    statevars.ndim = pi.size

    if nwalkers < 2 * statevars.ndim:
        print("WARNING: Number of walkers is less than 2 times the number "
              "of free parameters. Adjusting number of walkers to {}".format(2*statevars.ndim))
        statevars.nwalkers = 2*statevars.ndim

    # set up perturbation size
    pscales = []
    for par in likelihood.list_vary_params():
        val = likelihood.params[par]
        if par.startswith('per'):
            pscale = np.abs(val * 1e-5*np.log10(val))
            pscale_per = pscale
        elif par.startswith('tc'):
            pscale = 0.1
        else:
            pscale = np.abs(0.10 * val)

        pscales.append(pscale)
        
    pscales = np.array(pscales)

    statevars.samplers = []
    statevars.initial_positions = []
    for e in range(ensembles):
        lcopy = copy.deepcopy(likelihood)
        pi = lcopy.get_vary_params()
        p0 = np.vstack([pi]*statevars.nwalkers)
        p0 += [np.random.rand(statevars.ndim)*pscales for i in range(statevars.nwalkers)]
        statevars.initial_positions.append(p0)
        statevars.samplers.append(emcee.EnsembleSampler( 
            statevars.nwalkers, statevars.ndim, lcopy.logprob_array, threads=1))

    num_run = int(np.round(nrun / checkinterval))
    statevars.totsteps = nrun*statevars.nwalkers*statevars.ensembles
    statevars.mixcount = 0
    statevars.ismixed = 0
    statevars.burn_complete = False
    statevars.nburn = 0
    statevars.ncomplete = statevars.nburn
    statevars.pcomplete = 0
    statevars.rate = 0
    statevars.ar = 0
    statevars.mintz = -1
    statevars.maxgr = np.inf
    statevars.t0 = time.time()

    for r in range(num_run):
        t1 = time.time()
        jobs = []
        for i,sampler in enumerate(statevars.samplers):
            if sampler.flatlnprobability.shape[0] == 0:
                p1 = statevars.initial_positions[i]
            else:
                p1 = None
            jobs.append(statevars.server.submit(_crunch, (sampler, p1,
                                                          checkinterval)))
            
        for i, j in enumerate(jobs):
            statevars.samplers[i] = j()
            
        t2 = time.time()
        statevars.interval = t2 - t1

        # Use Threading
        ch = CheckThread(convergence_check, statevars.server, statevars.samplers)
        ch.start()

        # Use multiprocessing
        # result = pool.apply_async(convergence_check,
        #                 (statevars.server, statevars.samplers))

        # ch = CheckThread(status_message, statevars)
        # ch.start()
        
        #convergence_check(statevars.server, statevars.samplers)
        # Burn-in complete after maximum G-R statistic first reaches burnGR
        # reset samplers
        if not statevars.burn_complete and statevars.maxgr <= burnGR:
            server.wait()
            ch.join()
            for i, sampler in enumerate(statevars.samplers):
                statevars.initial_positions[i] = \
                    sampler._last_run_mcmc_result[0]
                sampler.reset()
                statevars.samplers[i] = sampler
            msg = (
                "\nDiscarding burn-in now that the chains are marginally "
                "well-mixed\n"
            )
            print(msg)
            statevars.nburn = statevars.ncomplete
            statevars.burn_complete = True

        if statevars.mixcount >= 5:
            server.wait()
            ch.join()
            tf = time.time()
            tdiff = tf - statevars.t0
            tdiff,units = utils.time_print(tdiff)
            msg = (
                "\nChains are well-mixed after {:d} steps! MCMC completed in "
                "{:3.1f} {:s}"
            ).format(statevars.ncomplete, tdiff, units)
            print(msg)
            break

    server.destroy()
            
    print("\n")        
    if statevars.ismixed and statevars.mixcount < 5: 
        msg = (
            "MCMC: WARNING: chains did not pass 5 consecutive convergence "
            "tests. They may be marginally well=mixed."
        )
        print(msg)
    elif not statevars.ismixed: 
        msg = (
            "MCMC: WARNING: chains did not pass convergence tests. They are "
            "likely not well-mixed."
        )
        print(msg)
        
    df = pd.DataFrame(
        statevars.tchains.reshape(statevars.ndim,
            statevars.tchains.shape[1]*statevars.tchains.shape[2]).transpose(),
        columns=likelihood.list_vary_params())
    df['lnprobability'] = np.hstack(statevars.lnprob)

    ch.join()
    
    return df
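
Example #5 relies on a `CheckThread` helper that is not shown on this page. From its call sites (constructed with a target function and its arguments, then start() and join()), it could be a thin threading.Thread wrapper along these lines (an assumed sketch, not the actual radvel class):

import threading

class CheckThread(threading.Thread):
    # Assumed sketch: run target(*args) on a background thread so the
    # convergence check can overlap with the next MCMC segment.
    def __init__(self, target, *args):
        super(CheckThread, self).__init__()
        self._target_func = target
        self._args = args

    def run(self):
        self._target_func(*self._args)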