예제 #1
0
파일: atmcmc.py 프로젝트: xiaolongma/beat
def init_stage(homepath,
               step,
               stage,
               model,
               n_jobs=1,
               progressbar=False,
               update=None,
               rm_flag=False):
    """
    Examine starting point of sampling, reload stages and initialise steps.

    Parameters
    ----------
    homepath : str
        directory holding the ``stage_*`` result directories; its parent
        directory is used as project directory and its basename as sampling
        mode when sampler parameters are reloaded
    step : step object of the sampler class
        used as-is for the initial stage; replaced by the reloaded step
        object when continuing the 'final' or an intermediate stage
    stage : str or int
        '0' (or 0) to continue / start the initial stage, 'final' to continue
        the final stage, any other integer-convertible value to continue
        that intermediate stage
    model : model instance
        entered as context while existing results are reloaded
    n_jobs : int
        number of parallel jobs; the ``len(chains) % n_jobs`` leftover chains
        of an incomplete stage are resampled here
    progressbar : boolean
        flag for displaying a progressbar while resampling leftover chains
    update : object or None
        if given, its ``apply`` method is called with the reloaded updates
    rm_flag : boolean
        if True, existing results of the stage are removed instead of
        being reloaded

    Returns
    -------
    tuple of (chains, step, update)
        ``chains`` is None when sampling (re)starts from scratch, otherwise
        the list returned by ``backend.check_multitrace`` (after the
        leftover chains were split off by ``utility.split_off_list``)

    Raises
    ------
    ValueError
        if ``stage`` is None
    """
    if stage is None:
        raise ValueError('stage has to be not None!')

    if stage != 'final' and int(stage) == 0:
        # continue or start initial stage; an integer 0 is accepted as well
        # (it used to fall through and try to reload 'stage_-1')
        step.stage = 0
        stage_path = os.path.join(homepath, 'stage_%i' % step.stage)
        draws = 1
    else:
        if stage == 'final':
            # continue sampling final stage from the last completed stage
            previous = backend.get_highest_sampled_stage(homepath)
        else:
            # continue sampling intermediate stage
            previous = int(stage) - 1

        logger.info('Loading parameters from completed stage_%i', previous)
        project_dir = os.path.dirname(homepath)
        mode = os.path.basename(homepath)
        step, updates = backend.load_sampler_params(
            project_dir, str(previous), mode)

        if update is not None:
            update.apply(updates)

        if stage == 'final':
            stage_path = os.path.join(homepath, 'stage_final')
        else:
            step.stage += 1
            stage_path = os.path.join(homepath, 'stage_%i' % step.stage)

        draws = step.n_steps

    chains = None
    if rm_flag:
        if os.path.exists(stage_path):
            logger.info('Removing previous sampling results ... %s',
                        stage_path)
            shutil.rmtree(stage_path)
        return chains, step, update

    with model:
        mtrace = None
        if os.path.exists(stage_path):
            # load incomplete stage results
            logger.info('Reloading existing results ...')
            mtrace = backend.load(stage_path, model=model)

        if mtrace is None or len(mtrace.chains) == 0:
            logger.info('Init new trace!')
        else:
            # continue sampling if traces exist
            logger.info('Checking for corrupted files ...')
            chains = backend.check_multitrace(
                mtrace, draws=draws, n_chains=step.n_chains)
            rest = len(chains) % n_jobs

            if rest > 0:
                logger.info('Fixing %i chains ...', rest)
                # process traces that are not a multiple of n_jobs
                rest_chains = utility.split_off_list(chains, rest)
                _iter_parallel_chains(
                    draws=draws,
                    step=step,
                    stage_path=stage_path,
                    progressbar=progressbar,
                    model=model,
                    n_jobs=rest,
                    chains=rest_chains)
                logger.info('Back to normal!')

    return chains, step, update
예제 #2
0
파일: base.py 프로젝트: mingzhaochina/beat
def iter_parallel_chains(draws,
                         step,
                         stage_path,
                         progressbar,
                         model,
                         n_jobs,
                         chains=None,
                         initializer=None,
                         initargs=(),
                         chunksize=None):
    """
    Do Metropolis sampling over all the chains with each chain being
    sampled 'draws' times. Parallel execution according to n_jobs.
    If jobs hang for any reason they are being killed after an estimated
    timeout. The chains in question are being rerun and the estimated timeout
    is added again.

    Parameters
    ----------
    draws : int
        number of steps that are taken within each Markov Chain
    step : step object of the sampler class, e.g.:
        :class:`beat.sampler.Metropolis`, :class:`beat.sampler.SMC`
    stage_path : str
        with absolute path to the directory where to store the sampling results
    progressbar : boolean
        flag for displaying a progressbar
    model : :class:`pymc3.model.Model` instance
        holds definition of the forward problem
    n_jobs : int
        number of jobs to run in parallel, must not be higher than the
        number of CPUs
    chains : list
        of integers to the chain numbers, if None then all chains from the
        step object are sampled
    initializer : function
        to run before execution of each sampling process
    initargs : tuple
        of arguments for the initializer
    chunksize : int
        number of chains to sample within each process; if None it is
        estimated from ``draws`` and the measured time per sample, and
        re-estimated for every rerun of corrupted chains

    Returns
    -------
    MultiTrace object
    """
    timeout = 0

    if chains is None:
        chains = list(range(step.n_chains))

    n_chains = len(chains)

    if n_chains == 0:
        mtrace = backend.load_multitrace(dirname=stage_path, model=model)

    max_int = np.iinfo(np.int32).max

    # while is necessary if any worker times out - rerun in case
    while n_chains > 0:
        logger.info('Initialising %i chain traces ...', n_chains)
        trace_list = [
            backend.TextChain(stage_path, model=model) for _ in chains]

        random_seeds = [randint(max_int) for _ in range(n_chains)]

        work = [
            (draws, step, step.population[step.resampling_indexes[chain]],
             trace, chain, None, progressbar, model, rseed)
            for chain, rseed, trace in zip(chains, random_seeds, trace_list)
        ]

        tps = step.time_per_sample(min(n_jobs, 10))
        logger.info('Serial time per sample: %f', tps)

        if chunksize is None:
            # Estimated per round (not cached in ``chunksize``): reruns of
            # corrupted chains usually have fewer chains than the first round.
            # Short jobs (few draws or fast samples) are spread evenly over
            # the workers; ``tps`` is checked for every draws >= 10, which
            # closes the former gap at exactly draws == 10.
            if draws < 10 or tps < 1.:
                chunk = int(np.ceil(float(n_chains) / n_jobs))
            else:
                chunk = n_jobs
        else:
            chunk = chunksize

        timeout += int(np.ceil(tps * draws)) * n_jobs + 10

        if n_jobs > 1:
            shared_params = [
                sparam for sparam in step.logp_forw.get_shared()
                if sparam.name in parallel._tobememshared
            ]

            logger.info('Data to be memory shared: %s',
                        list2string(shared_params))

            if len(shared_params) > 0:
                if len(parallel._shared_memory.keys()) == 0:
                    logger.info('Putting data into shared memory ...')
                    parallel.memshare_sparams(shared_params)
                else:
                    logger.info('Data already in shared memory!')

            else:
                logger.info('No data to be memshared!')

        else:
            logger.info('Not using shared memory.')

        p = parallel.paripool(_sample,
                              work,
                              chunksize=chunk,
                              timeout=timeout,
                              nprocs=n_jobs,
                              initializer=initializer,
                              initargs=initargs)

        logger.info('Sampling ...')

        # consume the pool's results; the returned values are discarded,
        # the traces are read back from stage_path below
        for _ in p:
            pass

        # return chain indexes that have been corrupted
        mtrace = backend.load_multitrace(dirname=stage_path, model=model)
        corrupted_chains = backend.check_multitrace(mtrace,
                                                    draws=draws,
                                                    n_chains=step.n_chains)

        n_chains = len(corrupted_chains)

        if n_chains > 0:
            logger.warning('%i Chains not finished sampling,'
                           ' restarting ...', n_chains)

        chains = corrupted_chains

    return mtrace
예제 #3
0
파일: smc.py 프로젝트: sandragharbi/beat
def _iter_parallel_chains(
        draws, step, stage_path, progressbar, model, n_jobs,
        chains=None):
    """
    Do Metropolis sampling over all the chains with each chain being
    sampled 'draws' times. Parallel execution according to n_jobs.
    If jobs hang for any reason they are being killed after an estimated
    timeout. The chains in question are being rerun and the estimated timeout
    is added again.

    Parameters
    ----------
    draws : int
        number of steps that are taken within each Markov Chain
    step : step object of the sampler class
        provides ``n_chains``, ``population``, ``resampling_indexes`` and
        ``time_per_sample``
    stage_path : str
        directory where the sampling results are written and read back from
    progressbar : boolean
        flag for displaying a progressbar
    model : model instance
        passed on to the traces, the sampling workers and the trace loader
    n_jobs : int
        number of processes to run in parallel
    chains : list
        of integers to the chain numbers, if None then all chains from the
        step object are sampled

    Returns
    -------
    the trace loaded by ``backend.load_multitrace`` from ``stage_path``
    """
    timeout = 0

    if chains is None:
        chains = list(range(step.n_chains))

    n_chains = len(chains)

    if n_chains == 0:
        mtrace = backend.load_multitrace(dirname=stage_path, model=model)

    max_int = np.iinfo(np.int32).max

    # while is necessary if any worker times out - rerun in case
    while n_chains > 0:
        logger.info('Initialising %i chain traces ...', n_chains)
        trace_list = [
            backend.TextChain(stage_path, model=model) for _ in chains]

        random_seeds = [randint(max_int) for _ in range(n_chains)]

        work = [(draws, step, step.population[step.resampling_indexes[chain]],
                 trace, chain, None, progressbar, model, rseed)
                for chain, rseed, trace in zip(
                    chains, random_seeds, trace_list)]

        tps = step.time_per_sample(10)

        # Short jobs (few draws or fast samples) are spread evenly over the
        # workers; ``tps`` is checked for every draws >= 10, which closes the
        # former gap at exactly draws == 10.
        if draws < 10:
            chunksize = int(np.ceil(float(n_chains) / n_jobs))
            # extra time budget per sample for very short chains
            tps += 5.
        elif tps < 1.:
            chunksize = int(np.ceil(float(n_chains) / n_jobs))
        else:
            chunksize = n_jobs

        timeout += int(np.ceil(tps * draws)) * n_jobs

        p = paripool.paripool(
            _sample, work, chunksize=chunksize, timeout=timeout, nprocs=n_jobs)

        logger.info('Sampling ...')

        # consume the pool's results; the returned values are discarded,
        # the traces are read back from stage_path below
        for _ in p:
            pass

        # return chain indexes that have been corrupted
        mtrace = backend.load_multitrace(dirname=stage_path, model=model)
        corrupted_chains = backend.check_multitrace(
            mtrace, draws=draws, n_chains=step.n_chains)

        n_chains = len(corrupted_chains)

        if n_chains > 0:
            logger.warning(
                '%i Chains not finished sampling,'
                ' restarting ...', n_chains)

        chains = corrupted_chains

    return mtrace