import os
import shutil

import numpy as np
from numpy.random import randint

# NOTE: the remaining names used below (logger, backend, utility, parallel,
# paripool, list2string, _sample) are assumed to be provided elsewhere in the
# surrounding module; they are not standard-library objects.


def init_stage(homepath, step, stage, model, n_jobs=1,
               progressbar=False, update=None, rm_flag=False):
    """
    Examine starting point of sampling, reload stages and initialise steps.
    """
    if stage is not None:
        if stage == '0':
            # continue or start initial stage
            step.stage = int(stage)
            stage_path = os.path.join(homepath, 'stage_%i' % step.stage)
            draws = 1

        elif stage == 'final':
            # continue sampling of the final stage
            last = backend.get_highest_sampled_stage(homepath)
            logger.info('Loading parameters from completed stage_%i' % last)

            project_dir = os.path.dirname(homepath)
            mode = os.path.basename(homepath)
            step, updates = backend.load_sampler_params(
                project_dir, str(last), mode)

            if update is not None:
                update.apply(updates)

            stage_path = os.path.join(homepath, 'stage_final')
            draws = step.n_steps

        else:
            # continue sampling of an intermediate stage
            stage = int(stage)
            logger.info(
                'Loading parameters from completed stage_%i' % (stage - 1))

            project_dir = os.path.dirname(homepath)
            mode = os.path.basename(homepath)
            step, updates = backend.load_sampler_params(
                project_dir, str(stage - 1), mode)

            if update is not None:
                update.apply(updates)

            step.stage += 1
            stage_path = os.path.join(homepath, 'stage_%i' % step.stage)
            draws = step.n_steps

        if rm_flag:
            chains = None
            if os.path.exists(stage_path):
                logger.info(
                    'Removing previous sampling results ... %s' % stage_path)
                shutil.rmtree(stage_path)
        else:
            with model:
                if os.path.exists(stage_path):
                    # load incomplete stage results
                    logger.info('Reloading existing results ...')
                    mtrace = backend.load(stage_path, model=model)

                    if len(mtrace.chains) > 0:
                        # continue sampling if traces exist
                        logger.info('Checking for corrupted files ...')
                        chains = backend.check_multitrace(
                            mtrace, draws=draws, n_chains=step.n_chains)
                        rest = len(chains) % n_jobs

                        if rest > 0:
                            logger.info('Fixing %i chains ...' % rest)
                            rest_chains = utility.split_off_list(chains, rest)
                            # sample the chains that are not a multiple of
                            # n_jobs separately
                            sample_args = {
                                'draws': draws,
                                'step': step,
                                'stage_path': stage_path,
                                'progressbar': progressbar,
                                'model': model,
                                'n_jobs': rest,
                                'chains': rest_chains}

                            _iter_parallel_chains(**sample_args)
                            logger.info('Back to normal!')
                    else:
                        logger.info('Init new trace!')
                        chains = None
                else:
                    logger.info('Init new trace!')
                    chains = None
    else:
        raise ValueError('stage must not be None!')

    return chains, step, update
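
# Usage sketch (illustrative only, not part of the original module): assuming
# `homepath` points at a <project_dir>/<mode> results directory and that a
# configured `step` object and pymc3 `model` already exist, restarting the
# final stage with previous results removed could look like:
#
#     chains, step, update = init_stage(
#         homepath, step, stage='final', model=model,
#         n_jobs=4, rm_flag=True)
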
def iter_parallel_chains(
        draws, step, stage_path, progressbar, model, n_jobs,
        chains=None, initializer=None, initargs=(), chunksize=None):
    """
    Do Metropolis sampling over all the chains, with each chain being sampled
    'draws' times. Parallel execution according to n_jobs. If jobs hang for
    any reason they are killed after an estimated timeout. The chains in
    question are rerun and the estimated timeout is added again.

    Parameters
    ----------
    draws : int
        number of steps that are taken within each Markov Chain
    step : step object of the sampler class, e.g.:
        :class:`beat.sampler.Metropolis`, :class:`beat.sampler.SMC`
    stage_path : str
        absolute path to the directory where to store the sampling results
    progressbar : boolean
        flag for displaying a progressbar
    model : :class:`pymc3.model.Model`
        instance that holds the definition of the forward problem
    n_jobs : int
        number of jobs to run in parallel, must not be higher than the
        number of CPUs
    chains : list
        of integers to the chain numbers, if None then all chains from the
        step object are sampled
    initializer : function
        to run before execution of each sampling process
    initargs : tuple
        of arguments for the initializer
    chunksize : int
        number of chains to sample within each process

    Returns
    -------
    MultiTrace object
    """
    timeout = 0

    if chains is None:
        chains = list(range(step.n_chains))

    n_chains = len(chains)

    if n_chains == 0:
        mtrace = backend.load_multitrace(dirname=stage_path, model=model)

    # the while loop is necessary in case any worker times out - those chains
    # are rerun
    while n_chains > 0:
        trace_list = []

        logger.info('Initialising %i chain traces ...' % n_chains)
        for chain in chains:
            trace_list.append(backend.TextChain(stage_path, model=model))

        max_int = np.iinfo(np.int32).max
        random_seeds = [randint(max_int) for _ in range(n_chains)]

        work = [
            (draws, step, step.population[step.resampling_indexes[chain]],
             trace, chain, None, progressbar, model, rseed)
            for chain, rseed, trace in zip(chains, random_seeds, trace_list)]

        tps = step.time_per_sample(np.minimum(n_jobs, 10))
        logger.info('Serial time per sample: %f' % tps)

        if chunksize is None:
            if draws < 10:
                chunksize = int(np.ceil(float(n_chains) / n_jobs))
            elif draws > 10 and tps < 1.:
                chunksize = int(np.ceil(float(n_chains) / n_jobs))
            else:
                chunksize = n_jobs

        timeout += int(np.ceil(tps * draws)) * n_jobs + 10

        if n_jobs > 1:
            shared_params = [
                sparam for sparam in step.logp_forw.get_shared()
                if sparam.name in parallel._tobememshared]

            logger.info(
                'Data to be memory shared: %s' % list2string(shared_params))

            if len(shared_params) > 0:
                if len(parallel._shared_memory.keys()) == 0:
                    logger.info('Putting data into shared memory ...')
                    parallel.memshare_sparams(shared_params)
                else:
                    logger.info('Data already in shared memory!')
            else:
                logger.info('No data to be memshared!')
        else:
            logger.info('Not using shared memory.')

        p = parallel.paripool(
            _sample, work, chunksize=chunksize, timeout=timeout,
            nprocs=n_jobs, initializer=initializer, initargs=initargs)

        logger.info('Sampling ...')

        for res in p:
            pass

        # determine chain indexes that have been corrupted
        mtrace = backend.load_multitrace(dirname=stage_path, model=model)
        corrupted_chains = backend.check_multitrace(
            mtrace, draws=draws, n_chains=step.n_chains)

        n_chains = len(corrupted_chains)

        if n_chains > 0:
            logger.warning(
                '%i chains did not finish sampling, restarting ...' % n_chains)

        chains = corrupted_chains

    return mtrace
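
# Usage sketch (illustrative only, not part of the original module): assuming
# an already configured SMC `step` object, a pymc3 `model` and a writable
# `stage_path`, all chains of the current stage could be sampled with four
# worker processes like this:
#
#     mtrace = iter_parallel_chains(
#         draws=step.n_steps, step=step, stage_path=stage_path,
#         progressbar=False, model=model, n_jobs=4)
#
# Corrupted or timed-out chains are detected via backend.check_multitrace and
# resampled automatically before the MultiTrace is returned.
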
def _iter_parallel_chains(
        draws, step, stage_path, progressbar, model, n_jobs, chains=None):
    """
    Do Metropolis sampling over all the chains, with each chain being sampled
    'draws' times. Parallel execution according to n_jobs. If jobs hang for
    any reason they are killed after an estimated timeout. The chains in
    question are rerun and the estimated timeout is added again.
    """
    timeout = 0

    if chains is None:
        chains = list(range(step.n_chains))

    n_chains = len(chains)

    if n_chains == 0:
        mtrace = backend.load_multitrace(dirname=stage_path, model=model)

    # the while loop is necessary in case any worker times out - those chains
    # are rerun
    while n_chains > 0:
        trace_list = []

        logger.info('Initialising %i chain traces ...' % n_chains)
        for chain in chains:
            trace_list.append(backend.TextChain(stage_path, model=model))

        max_int = np.iinfo(np.int32).max
        random_seeds = [randint(max_int) for _ in range(n_chains)]

        work = [(draws, step,
                 step.population[step.resampling_indexes[chain]],
                 trace, chain, None, progressbar, model, rseed)
                for chain, rseed, trace in zip(
                    chains, random_seeds, trace_list)]

        tps = step.time_per_sample(10)

        if draws < 10:
            chunksize = int(np.ceil(float(n_chains) / n_jobs))
            tps += 5.
        elif draws > 10 and tps < 1.:
            chunksize = int(np.ceil(float(n_chains) / n_jobs))
        else:
            chunksize = n_jobs

        timeout += int(np.ceil(tps * draws)) * n_jobs

        p = paripool.paripool(
            _sample, work, chunksize=chunksize, timeout=timeout,
            nprocs=n_jobs)

        logger.info('Sampling ...')

        for res in p:
            pass

        # determine chain indexes that have been corrupted
        mtrace = backend.load_multitrace(dirname=stage_path, model=model)
        corrupted_chains = backend.check_multitrace(
            mtrace, draws=draws, n_chains=step.n_chains)

        n_chains = len(corrupted_chains)

        if n_chains > 0:
            logger.warning(
                '%i chains did not finish sampling, restarting ...' % n_chains)

        chains = corrupted_chains

    return mtrace
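
# Illustrative arithmetic (assumed numbers, not from the original module):
# with a measured serial time per sample of tps = 0.5 s, draws = 100 and
# n_jobs = 4, the helper above would estimate the per-round timeout as
#
#     int(np.ceil(0.5 * 100)) * 4 == 200 seconds,
#
# and every rerun of timed-out chains adds the same amount again, so hanging
# workers are killed and restarted rather than stalling the sampler forever.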