def __init__(self, mp_context=None):
    """Initialise the executor and remember its multiprocessing context.

    :param mp_context: context object to store on the executor. When it is
        falsy, the context returned by ``multiprocess.get_context()`` is
        used instead.
    """
    super(ProcessExecutor, self).__init__()
    if mp_context:
        self._ctx = mp_context
        return
    # ``multiprocess`` is imported lazily: it is only needed when the caller
    # did not hand in a context.
    # noinspection PyUnresolvedReferences
    from multiprocess import get_context
    self._ctx = get_context()
def ctx(self, starter):
    """Switch to the context named ``starter`` and attach registered objects.

    Stores ``get_context(starter)`` in ``self._ctx`` and the name in
    ``self._ctxname``, then sets every ``(obj, name)`` pair found in
    ``self._registry[starter]`` as an attribute of the new context. A pair
    whose name is ``None`` is attached under ``obj.__name__``.
    """
    self._ctx = get_context(starter)
    self._ctxname = starter
    for obj, name in self._registry[starter]:
        attr_name = obj.__name__ if name is None else name
        setattr(self._ctx, attr_name, obj)
def __init__(self, device, ctx=mp.get_context("spawn")):
    """
    Constructor method

    Hands ``device`` and ``ctx`` to the parent constructor, then prepares the
    pose-display state: no pose yet, a ``ClearableMPQueue`` built with ``2``
    as its first argument on ``self.ctx``, and no pose process.

    NOTE(review): ``self.ctx`` is presumably assigned by the parent
    constructor -- confirm.
    """
    super().__init__(device, ctx)
    self.display_pose = None
    pose_queue = ClearableMPQueue(2, ctx=self.ctx)
    self.display_pose_queue = pose_queue
    self.pose_process = None
def test_spawned_sharedmem():
    """Check that workers of a 'spawn' pool can write into a shared array.

    A (10, 3) array of type code 'd' is obtained from
    ``SpawnSharedmemManager.shared_ndarray`` and handed to each pool worker
    through ``init_shared_array``. ``write_to_array`` is mapped over the row
    indices, after which every row ``i`` is expected to hold the value ``i``
    in all of its columns.
    """
    array = SpawnSharedmemManager.shared_ndarray((10, 3), 'd')
    mm = SpawnSharedmemManager(array)
    ctx = mp.get_context('spawn')
    with ctx.Pool(processes=max(1, mp.cpu_count() - 3),
                  initializer=init_shared_array,
                  initargs=(mm,)) as p:
        p.map(write_to_array, range(len(array)))
        p.close()
        p.join()
    # Row i of the expected array is filled with the value i.
    expected_result = np.repeat(np.arange(array.shape[0]),
                                array.shape[1]).reshape(array.shape)
    # The message names the start method under test ('spawn'); it previously
    # said "Forked", which pointed at the wrong test on failure.
    assert np.all(expected_result == array), 'Spawned shared-memory writing failed'
def ranch(nprocs: int,
          fn: Callable,
          *args,
          parent_rank: int = 0,
          host_rank: int = 0,
          ctx=None,
          **kwargs):
    '''Launch `fn(*args, **kwargs)` to `nprocs` spawned processes.

    Local rank, global rank (multiple hosts), and world size are set in
    os.environ['LOCAL_RANK','RANK','WORLD_SIZE'] respectively.
    Parent process can participate as rank_{parent_rank}.
    Can optionally apply a context manager `ctx` around `fn(...)`.

    Args:
        nprocs: Number of workers; must be > 0.
        fn: Callable executed by every worker.
        *args: Positional arguments passed to `fn`.
        parent_rank: Local rank taken by the calling process, or ``None`` to
            run `fn` only in child processes.
        host_rank: Index of this host; the global rank of a worker is
            ``local_rank + host_rank * nprocs``.
        ctx: Optional context manager; when truthy, `fn` is wrapped with
            `_contextualize(fn, ctx)`.
        **kwargs: Keyword arguments passed to `fn`.

    Returns:
        The return value of `fn` in the calling process when `parent_rank`
        is not ``None``; otherwise the list of child process objects. In
        both cases all started children have been joined before returning.

    Raises:
        AssertionError: If `nprocs` or `parent_rank` is out of range.
        Exception: Any exception raised while launching or running `fn` in
            the calling process is re-raised wrapped in a plain `Exception`
            (chained to the original).
    '''
    assert nprocs > 0, ValueError(
        "nprocs: # of processes to launch must be > 0")
    children_ranks = list(range(nprocs))
    if parent_rank is not None:
        assert 0 <= parent_rank < nprocs, ValueError(
            f"Out of range parent_rank:{parent_rank}, must be 0 <= parent_rank < {nprocs}"
        )
        # The parent takes this rank itself, so no child is started for it.
        children_ranks.pop(parent_rank)
    multiproc_ctx = mp.get_context("forkserver")
    procs = []
    try:
        os.environ["WORLD_SIZE"], base_rank = str(nprocs), host_rank * nprocs
        target_fn = _contextualize(fn, ctx) if ctx else fn
        for rank in children_ranks:
            # NOTE(review): the ranks are published through this process'
            # environment right before each start -- presumably the children
            # pick them up from there; confirm this holds for 'forkserver'.
            os.environ.update({
                "LOCAL_RANK": str(rank),
                "RANK": str(rank + base_rank)
            })
            p = multiproc_ctx.Process(target=target_fn,
                                      args=args,
                                      kwargs=kwargs)
            # Only track a process once it has really started: joining an
            # unstarted process in the ``finally`` block below raises and
            # would hide the original start-up error.
            p.start()
            procs.append(p)
        if parent_rank is not None:
            # also run it in current process at a rank
            os.environ.update({
                "LOCAL_RANK": str(parent_rank),
                "RANK": str(parent_rank + base_rank)
            })
            return target_fn(*args, **kwargs)
        else:
            return procs
    except Exception as e:
        raise Exception(e) from e
    finally:
        # Leave the caller's environment clean and wait for the children.
        for k in ["WORLD_SIZE", "RANK", "LOCAL_RANK"]:
            os.environ.pop(k, None)
        for p in procs:
            p.join()
def start_processing_pool():
    """Create the worker pool used for multicore processing.

    The pool comes from the 'forkserver' multiprocessing context and is
    sized by ``config.max_core_usage``, where ``'max'`` means
    ``mp.cpu_count()``. When ``config.multicore_processing`` is falsy no
    pool is created and ``None`` is returned.
    """
    if not config.multicore_processing:
        return None

    use_all_cores = config.max_core_usage == 'max'
    workers = mp.cpu_count() if use_all_cores else config.max_core_usage

    # This stops segmentation fault for MacOS
    forkserver_ctx = mp.get_context('forkserver')
    if config.verbose:
        print(f"\nStarting processing pool using {workers} cores.")
    return forkserver_ctx.Pool(processes=workers)
def __init__(self, max_workers=None, mp_context=None, initializer=None, initargs=()):
    """Store the pool configuration; the pool itself is not created here.

    :param max_workers: stored unchanged as ``self._max_workers``.
    :param mp_context: context object to use; when falsy, the context
        returned by ``multiprocess.get_context()`` is used.
    :param initializer: stored unchanged as ``self._initializer``.
    :param initargs: stored unchanged as ``self._initargs``.
    """
    super(ProcessPoolExecutor, self).__init__()
    if mp_context:
        context = mp_context
    else:
        # noinspection PyUnresolvedReferences
        from multiprocess import get_context
        context = get_context()
    self._max_workers = max_workers
    self._ctx = context
    self._initializer, self._initargs = initializer, initargs
    # The lock is created from the selected context.
    self.lock = self._ctx.Lock()
    self.pool = None
def __init__(self, device, ctx=mp.get_context("spawn")):
    """
    Constructor method

    Allocates the shared buffers and queues for ``device``.

    :param device: object exposing ``im_size``; ``im_size[0]`` and
        ``im_size[1]`` size the frame buffer (presumably width and height --
        confirm against the device class).
    :param ctx: multiprocessing context used for the shared arrays and the
        queues created here.
    """
    self.device = device
    self.ctx = ctx

    res = self.device.im_size
    # The shared arrays are created from the same context as the queues.
    # Objects (in particular locks) that belong to one context may not be
    # usable by processes started from another one, so the global
    # ``mp.Array`` (default context) is not used here.
    self.frame_shared = self.ctx.Array(ctypes.c_uint8, res[1] * res[0] * 3)
    # numpy view on the shared buffer, shaped (res[1], res[0], 3).
    self.frame = np.frombuffer(self.frame_shared.get_obj(),
                               dtype="uint8").reshape(res[1], res[0], 3)
    self.frame_time_shared = self.ctx.Array(ctypes.c_double, 1)
    # numpy view on the single shared double.
    self.frame_time = np.frombuffer(self.frame_time_shared.get_obj(),
                                    dtype="d")

    self.q_to_process = ClearableMPQueue(ctx=self.ctx)
    self.q_from_process = ClearableMPQueue(ctx=self.ctx)
    self.write_frame_queue = ClearableMPQueue(ctx=self.ctx)

    self.capture_process = None
    self.writer_process = None
def estimate_param_scan(estimator, X, param_sets, evaluate=None, evaluate_args=None, failfast=True,
                        return_estimators=False, n_jobs=1, progress_reporter=None, show_progress=True,
                        return_exceptions=False):
    """ Runs multiple estimations using a list of parameter settings

    Parameters
    ----------
    estimator : Estimator object or class
        An estimator object that provides an estimate(X, **params) function.
        If only a class is provided here, the Estimator objects will be
        constructed with default parameter settings, and the parameter settings
        from param_sets for each estimation. If you want to specify other
        parameter settings for those parameters not specified in param_sets,
        construct an Estimator before and pass the object.

    X : object
        The input data; it is passed unchanged to every estimation.

    param_sets : iterable over dictionaries
        An iterable that provides parameter settings. Each element defines a
        parameter set, for which an estimation will be run using these
        parameters in estimate(X, **params). All other parameter settings will
        be taken from the default settings in the estimator object.
        The iterable is copied into a list once, so one-shot iterables
        (e.g. generators) are supported as well.

    evaluate : str or list of str, optional
        The given methods or properties will be called on the estimated
        models, and their results will be returned instead of the full models.
        This may be useful for reducing memory overhead.

    evaluate_args: iterable of iterable, optional
        Arguments to be passed to evaluated methods. Note, that size has to match to the size of evaluate.

    failfast : bool
        If True, will raise an exception when estimation failed with an exception
        or trying to calls a method that doesn't exist. If False, will simply
        return None in these cases.

    return_estimators: bool
        If True, return a list estimators in addition to the models.

    n_jobs : int or None, default=1
        Number of processes to use. If None, the value is obtained from
        pyemma._base.parallel.get_n_jobs. A process pool is only used if
        n_jobs > 1 and os.name is 'posix'; otherwise the estimations run
        one after another in the calling process.

    progress_reporter : object, optional
        Object providing _progress_context, _progress_register and
        _progress_update; it is notified once per finished parameter set.
        If None, a MagicMock placeholder is used and nothing is reported.

    show_progress: bool
        if the given estimator supports show_progress interface, we set the flag
        prior doing estimations.

    return_exceptions: bool, default=False
        if failfast is False while this setting is True, returns the exception thrown at the actual grid element,
        instead of None.

    Returns
    -------
    models : list of model objects or evaluated function values
        A list of estimated models in the same order as param_sets. If evaluate
        is given, each element will contain the results from these method
        evaluations.

    estimators (optional) : list of estimator objects. These are returned only
        if return_estimators=True

    Raises
    ------
    ValueError
        If both evaluate and evaluate_args are given and differ in length.

    Examples
    --------

    Estimate a maximum likelihood Markov model at lag times 1, 2, 3.

    >>> from pyemma.msm.estimators import MaximumLikelihoodMSM, BayesianMSM
    >>>
    >>> dtraj = [0,0,1,2,1,0,1,0,1,2,2,0,0,0,1,1,2,1,0,0,1,2,1,0,0,0,1,1,0,1,2]  # mini-trajectory
    >>> param_sets=param_grid({'lag': [1,2,3]})
    >>>
    >>> estimate_param_scan(MaximumLikelihoodMSM, dtraj, param_sets, evaluate='timescales')
    [array([ 1.24113168, 0.77454377]), array([ 2.65266698, 1.42909842]), array([ 5.34810405, 1.14784446])]

    Now we additionally ask for samples of the timescales, with failfast=False.

    >>> estimate_param_scan(MaximumLikelihoodMSM, dtraj, param_sets, failfast=False,
    ...     evaluate=['timescales', 'timescales_samples']) # doctest: +SKIP
    [[array([ 1.24113168, 0.77454377]), None], [array([ 2.48226337, 1.54908754]), None], [array([ 3.72339505, 2.32363131]), None]]

    We get Nones because the MaximumLikelihoodMSM estimator doesn't provide timescales_samples. Use for example
    a Bayesian estimator for that.

    Now we also want to get samples of the timescales using the BayesianMSM.

    >>> estimate_param_scan(BayesianMSM, dtraj, param_sets, show_progress=False,
    ...     evaluate=['timescales', 'sample_f'], evaluate_args=((), ('timescales', ))) # doctest: +SKIP
    [[array([ 1.24357685, 0.77609028]), [array([ 1.5963252 , 0.73877883]), array([ 1.29915847, 0.49004912]), array([ 0.90058583, 0.73841786]), ...
    ]]

    """
    # make sure we have an estimator object
    estimator = get_estimator(estimator)
    if hasattr(estimator, 'show_progress'):
        estimator.show_progress = show_progress

    # param_sets is iterated twice below (once to create the estimators, once
    # to pair them with their parameters). Materialize it so that a one-shot
    # iterable is not exhausted by the first pass, which would silently
    # produce an empty result.
    param_sets = list(param_sets)

    if n_jobs is None:
        from pyemma._base.parallel import get_n_jobs
        n_jobs = get_n_jobs(logger=getattr(estimator, 'logger', None))

    # if we want to return estimators, make clones. Otherwise just copy references.
    # For parallel processing we always need clones.
    # Also if the Estimator is its own Model, we have to clone.
    from pyemma._base.model import Model
    if (return_estimators or
            n_jobs > 1 or n_jobs is None or
            isinstance(estimator, Model)):
        estimators = [clone_estimator(estimator) for _ in param_sets]
    else:
        estimators = [estimator for _ in param_sets]

    # only show progress of parameter study.
    if hasattr(estimators[0], 'show_progress'):
        for e in estimators:
            e.show_progress = False

    # if we evaluate, make sure we have a list of functions to evaluate
    if _types.is_string(evaluate):
        evaluate = [evaluate]
    if _types.is_string(evaluate_args):
        evaluate_args = [evaluate_args]

    if evaluate is not None and evaluate_args is not None and len(evaluate) != len(evaluate_args):
        raise ValueError("length mismatch: evaluate ({}) and evaluate_args ({})".format(len(evaluate), len(evaluate_args)))

    logger_available = hasattr(estimators[0], 'logger')
    if logger_available:
        logger = estimators[0].logger
    if progress_reporter is None:
        # The mock accepts every call below (including use as a context
        # manager), so both code paths can treat the reporter uniformly.
        from unittest.mock import MagicMock
        ctx = progress_reporter = MagicMock()
        callback = None
    else:
        ctx = progress_reporter._progress_context('param-scan')
        callback = lambda _: progress_reporter._progress_update(1, stage='param-scan')

        progress_reporter._progress_register(len(estimators), stage='param-scan',
                                             description="estimating %s" % str(estimator.__class__.__name__))

    # TODO: test on win, osx
    if n_jobs > 1 and os.name == 'posix':
        if logger_available:
            logger.debug('estimating %s with n_jobs=%s', estimator, n_jobs)
        # iterate over parameter settings
        limit_threads = True
        task_iter = ((estimator, param_set, X,
                      evaluate, evaluate_args, failfast, return_exceptions, limit_threads)
                     for estimator, param_set in zip(estimators, param_sets))
        if system() == "Linux":
            pool = get_context("spawn").Pool(processes=n_jobs)
        else:
            pool = get_context().Pool(processes=n_jobs)
        args = list(task_iter)

        from contextlib import closing

        def error_callback(*args, **kw):
            if failfast:
                # TODO: can we be specific here? eg. obtain the stack of the actual process or is this the master proc?
                raise Exception('something failed')

        with closing(pool), ctx:
            res_async = [pool.apply_async(_estimate_param_scan_worker, a, callback=callback,
                                          error_callback=error_callback) for a in args]
            res = [x.get() for x in res_async]

    # if n_jobs=1 don't invoke the pool, but directly dispatch the iterator
    else:
        if logger_available:
            logger.debug('estimating %s with n_jobs=1 because of the setting or '
                         'you not have a POSIX system', estimator)
        res = []
        with ctx:
            for estimator, param_set in zip(estimators, param_sets):
                res.append(_estimate_param_scan_worker(estimator, param_set, X,
                                                       evaluate, evaluate_args, failfast, return_exceptions, False))
                if progress_reporter is not None:
                    progress_reporter._progress_update(1, stage='param-scan')

    # done
    if return_estimators:
        return res, estimators
    else:
        return res
import multiprocess as mp
import threading


# Threading context so we can use threading as backend as well
class ThreadingContext:
    """Namespace mimicking a multiprocessing context for the threading backend.

    It only collects classes and factories as attributes, so code written
    against a context object can look up ``Event``, ``Lock``, etc. on it.
    """

    Event = threading.Event
    Lock = threading.Lock
    Thread = threading.Thread

    # threading doesn't have Array and JoinableQueue, so we take it from multiprocessing. Both are thread-safe. We need
    # the Process class for the MPIRE insights SyncManager instance.
    Array = mp.Array
    JoinableQueue = mp.JoinableQueue
    Process = mp.Process


# Only register the start methods this platform actually offers: asking for an
# unavailable one (e.g. 'fork' on Windows) raises ValueError and would make
# importing this module fail. Where all three exist the mapping is unchanged.
MP_CONTEXTS = {
    method: mp.get_context(method)
    for method in ('fork', 'forkserver', 'spawn')
    if method in mp.get_all_start_methods()
}
MP_CONTEXTS['threading'] = ThreadingContext
import os
import pdb
import multiprocessing
import pickle
import itertools
import datetime
import operator
import time
import math
import copy

# John hack
# Module-level 'spawn' context taken from the `multiprocess` package (a
# separate package from the stdlib `multiprocessing` imported above).
# NOTE(review): presumably used later in this file to create pools/processes
# instead of changing the global start method (see the disabled line below)
# -- confirm against the call sites.
import multiprocess
ctx = multiprocess.get_context("spawn")
# multiprocess.set_start_method('spawn',force=True)
# from multiprocess import Pool as mpPool

import matplotlib as M
import matplotlib.pyplot as plt
import numpy as N
import pandas
import sklearn.preprocessing
import sklearn.decomposition
# from scipy.spatial import KDTree
import scipy.spatial.distance as ssd
from mpl_toolkits.basemap import Basemap

import evac.utils as utils
from evac.stats.detscores import DetScores
# if it doesn't find the file read the script first without the --use_goodreads argument G = pickle.load(open("pickled_graphs/small_graph.p", "rb")) else: G = nx.DiGraph() books_memory = [] for id_ in metadata_calibre.index: with open(f'books_raw/{id_}.txt', 'r') as book_text: books_memory.append(book_text.readlines()) iterator = list( zip(metadata_calibre.index, metadata_calibre['aliases'], metadata_calibre['clean_title'], books_memory)) ctx = mp.get_context('spawn') num_workers = 20 if not args.use_citation_model else 4 p = ctx.Pool(num_workers) f = book_processer jobs = [] send_models = True for i in range(len(iterator)): if args.use_citation_model: job = p.apply_async(f, [iterator[i], model, tokenizer]) else:
def __init__(self, maxsize=0, ctx=mp.get_context("spawn")):
    """Forward ``maxsize`` and ``ctx`` to the parent constructor.

    The ``ctx`` default is evaluated once, when this method is defined, so
    every instance created without an explicit ``ctx`` receives the object
    returned by that single ``mp.get_context("spawn")`` call.
    """
    super().__init__(maxsize, ctx=ctx)
def __init__(self, *args, **kwargs):
    """Build the parent object on the default context and add a buffer.

    All positional and keyword arguments are forwarded to the parent
    constructor together with ``ctx=mp.get_context()``; callers therefore
    must not pass ``ctx`` themselves. ``self.buff`` starts as an empty dict.
    """
    super().__init__(*args, ctx=mp.get_context(), **kwargs)
    self.buff = {}
def main(settings, rescue_running=[]):
    """
    Perform the primary loop of building, submitting, monitoring, and analyzing jobs.

    This function works via a loop of calls to thread.process and thread.interpret for each thread that hasn't
    terminated, until either the global termination criterion is met or all the individual threads have completed.

    Unless `rescue_running` is given, the function first handles resampling (returning early if requested),
    prepares the working directory, dumps the settings to settings.pkl (unless settings.dont_dump), builds the
    threads and changes the current directory to the working directory.

    Parameters
    ----------
    settings : argparse.Namespace
        Settings namespace object
    rescue_running : list
        List of threads passed in from handle_loop_exception, containing running threads. If given, setup is skipped and
        the function proceeds directly to the main loop.

    Returns
    -------
    exit_message : str
        A message indicating the status of ATESA at the end of main

    Raises
    ------
    RuntimeError
        If the working directory state conflicts with the overwrite/restart settings, or if the temporary file
        ATESA_TEMP_CVS.txt already exists when cvs.txt has to be preserved.

    """

    if not rescue_running:
        # Implement resample
        if settings.job_type in ['aimless_shooting', 'committor_analysis'] and settings.resample:
            # Store settings object in the working directory for compatibility with analysis/utility scripts
            if not settings.dont_dump:
                temp_settings = copy.deepcopy(settings)  # initialize temporary copy of settings to modify
                temp_settings.__dict__.pop('env')  # env attribute is not picklable
                # Context manager so the file is flushed and closed deterministically
                with open(settings.working_directory + '/settings.pkl', 'wb') as settings_file:
                    pickle.dump(temp_settings, settings_file, protocol=2)
            # Run resampling
            if settings.job_type == 'aimless_shooting':
                utilities.resample(settings, partial=False, full_cvs=settings.full_cvs)
                if settings.information_error_checking:  # update info_err.out if called for by settings
                    information_error.main()
            elif settings.job_type == 'committor_analysis':
                resample_committor_analysis.resample_committor_analysis(settings)
            return 'Resampling complete'

        # Make working directory if it does not exist, handling overwrite and restart as needed
        if os.path.exists(settings.working_directory):
            if settings.overwrite and not settings.restart:
                if os.path.exists(settings.working_directory + '/cvs.txt'):  # a kludge to avoid removing cvs.txt
                    if os.path.exists('ATESA_TEMP_CVS.txt'):
                        raise RuntimeError('tried to create temporary file ATESA_TEMP_CVS.txt in directory: '
                                           + os.getcwd() + ', but it already exists. Please move, delete, or rename it.')
                    shutil.move(settings.working_directory + '/cvs.txt', 'ATESA_TEMP_CVS.txt')
                shutil.rmtree(settings.working_directory)
                os.mkdir(settings.working_directory)
                if os.path.exists('ATESA_TEMP_CVS.txt'):  # continuation of aforementioned kludge
                    shutil.move('ATESA_TEMP_CVS.txt', settings.working_directory + '/cvs.txt')
            elif not settings.restart and glob.glob(settings.working_directory + '/*') == [settings.working_directory + '/cvs.txt']:
                # Occurs when restart = False, overwrite = False, and auto_cvs is used
                pass
            elif not settings.restart:
                raise RuntimeError('Working directory ' + settings.working_directory + ' already exists, but overwrite '
                                   '= False and restart = False. Either change one of these two settings or choose a '
                                   'different working directory.')
        else:
            if not settings.restart:
                os.mkdir(settings.working_directory)
            else:
                raise RuntimeError('Working directory ' + settings.working_directory + ' does not yet exist, but '
                                   'restart = True.')

        # Store settings object in the working directory for compatibility with analysis/utility scripts
        if os.path.exists(settings.working_directory + '/settings.pkl'):  # for checking for need for resample later
            with open(settings.working_directory + '/settings.pkl', 'rb') as settings_file:
                previous_settings = pickle.load(settings_file)
            settings.previous_cvs = previous_settings.cvs
            try:
                settings.previous_information_error_max_dims = previous_settings.information_error_max_dims
            except AttributeError:
                pass
            try:
                settings.previous_information_error_lmax_string = previous_settings.information_error_lmax_string
            except AttributeError:
                pass
        if not settings.dont_dump:
            temp_settings = copy.deepcopy(settings)  # initialize temporary copy of settings to modify
            temp_settings.__dict__.pop('env')  # env attribute is not picklable (update: maybe no longer true, but doesn't matter)
            with open(settings.working_directory + '/settings.pkl', 'wb') as settings_file:
                pickle.dump(temp_settings, settings_file, protocol=2)

        # Build or load threads
        allthreads = init_threads(settings)

        # Move runtime to working directory
        os.chdir(settings.working_directory)

        running = allthreads.copy()  # to be pruned later by thread.process()
        attempted_rescue = False  # to keep track of general error handling below
    else:
        with open(settings.working_directory + '/restart.pkl', 'rb') as restart_file:
            allthreads = pickle.load(restart_file)
        running = rescue_running
        attempted_rescue = True

    # Initialize threads with first process step
    try:
        if not rescue_running:  # if rescue_running, this step has already finished and we just want the while loop
            for thread in allthreads:
                running = thread.process(running, settings)
    except Exception as e:
        if settings.restart:
            print('The following error occurred while attempting to initialize threads from restart.pkl. It may be '
                  'corrupted.')
            #'If you haven\'t already done so, consider running verify_threads.py to remove corrupted threads from this file.'
        raise e

    try:
        if settings.job_type == 'aimless_shooting' and len(os.sched_getaffinity(0)) > 1:
            # Initialize Manager for shared data across processes; this is necessary because multiprocessing is being
            # retrofitted to code designed for serial processing, but it works!
            manager = Manager()

            # Setup Managed allthreads list
            managed_allthreads = []
            for thread in allthreads:
                thread_dict = thread.__dict__
                thread_history_dict = thread.history.__dict__
                managed_thread = Thread()
                managed_thread.history = manager.Namespace()
                managed_thread.__dict__.update(thread_dict)
                managed_thread.history.__dict__.update(thread_history_dict)
                managed_allthreads.append(managed_thread)
            allthreads = manager.list(managed_allthreads)

            # Setup Managed settings Namespace
            settings_dict = settings.__dict__
            managed_settings = manager.Namespace()
            # Need to explicitly update every key because of how the Managed Namespace works.
            # Updating managed_settings.__dict__ doesn't work; setattr performs the same per-attribute
            # assignment that was previously done by exec-ing a generated statement.
            for key in settings_dict.keys():
                setattr(managed_settings, key, settings_dict[key])

            # Distribute processes among available core Pool
            with get_context("spawn").Pool(len(os.sched_getaffinity(0))) as p:
                p.starmap(main_loop, zip(itertools.repeat(managed_settings), itertools.repeat(allthreads),
                                         [[thread] for thread in allthreads]))
        else:
            main_loop(settings, allthreads, running)
    except AttributeError:  # os.sched_getaffinity raises AttributeError on non-UNIX systems.
        # NOTE(review): this handler also catches an AttributeError raised anywhere inside the branches above
        # (including main_loop) and then runs main_loop serially -- confirm that is acceptable.
        main_loop(settings, allthreads, running)

    ## Deprecated thread pool
    # pool = ThreadPool(len(allthreads))
    # func = partial(main_loop, settings)
    # results = pool.map(func, [[thread] for thread in allthreads])

    jobtype = factory.jobtype_factory(settings.job_type)
    jobtype.cleanup(settings)

    return 'ATESA run exiting normally'
import threading

# If multiprocess is installed we want to use that as it has more capabilities than regular multiprocessing (e.g.,
# pickling lambdas and functions located in __main__)
try:
    import multiprocess as mp
except ImportError:
    import multiprocessing as mp


# Threading context so we can use threading as backend as well
class ThreadingContext:
    """Namespace mimicking a multiprocessing context for the threading backend.

    It only collects classes and factories as attributes, so code written
    against a context object can look up ``Event``, ``Lock``, etc. on it.
    """

    Event = threading.Event
    Lock = threading.Lock
    Thread = threading.Thread

    # threading doesn't have Array and JoinableQueue, so we take it from multiprocessing. Both are thread-safe
    Array = mp.Array
    JoinableQueue = mp.JoinableQueue


# Only register the start methods this platform actually offers: asking for an
# unavailable one (e.g. 'fork' on Windows) raises ValueError and would make
# importing this module fail. Where all three exist the mapping is unchanged.
MP_CONTEXTS = {
    method: mp.get_context(method)
    for method in ('fork', 'forkserver', 'spawn')
    if method in mp.get_all_start_methods()
}
MP_CONTEXTS['threading'] = ThreadingContext
def _execute(
        workflow: Workflow,
        workers: int,
        resources: Dict[str, int],
        context: Optional[BaseContext] = None,
):  # noqa
    """Run the tasks of ``workflow`` on ``workers`` worker objects.

    Tasks are sent to the workers through ``q_set.q_in``; results come back
    on ``q_set.q_out`` and errors on ``q_set.q_err``. The loop ends when no
    task is pending and none is running.

    Args:
        workflow: Provides the task list (``to_task_list()``) and ``storage``.
        workers: Number of ``Worker`` objects to start.
        resources: Passed to ``ResourceManager``.
        context: Multiprocessing context; ``mp.get_context()`` when ``None``.

    Raises:
        Termination: If a worker is found dead or an error message arrives
            on ``q_set.q_err``.
    """
    if context is None:
        context = mp.get_context()
    # Upper bound on queued-but-unstarted messages in q_in before the
    # scheduler stops feeding it.
    buffer = 10
    manager = context.Manager()
    q_set = QueueSet(manager)
    # Pending tasks, keyed by task id.
    tasks = {task.task_id: task for task in workflow.to_task_list()}
    resource_manager = ResourceManager(resources)
    ref_manager = ReferenceManager(tasks=tasks, storage=workflow.storage)
    # Ids of tasks that were sent to a worker and have not reported back yet.
    running: List[str] = []
    try:
        # started workers
        ws: List[Worker] = []
        for _ in range(workers):
            w = Worker(q_set, workflow.storage, context)
            w.run()
            ws.append(w)
        try:
            while len(tasks) > 0 or len(running) > 0:
                for worker in ws:
                    if not worker.is_alive():
                        raise Termination("Detected unexpectedly dead worker ")
                # Drain every result message that is currently available.
                try:
                    while True:
                        msg: Message = q_set.q_out.get_nowait()
                        assert msg.kind in (Kind.DONE, Kind.GENERATED)
                        if msg.kind == Kind.DONE:
                            running.remove(msg.content["task"].task_id)
                            resource_manager.remove(msg.content["task"])
                            ref_manager.remove(msg.content["task"])
                        elif msg.kind == Kind.GENERATED:
                            running.remove(msg.content["task"].task_id)
                            resource_manager.remove(msg.content["task"])
                            # The finished task produced further tasks: add
                            # those that are not already running to the
                            # pending set.
                            new_tasks: Dict[str, Task] = msg.content["tasks"]
                            for task_id, task in new_tasks.items():
                                if task_id not in running:
                                    tasks[task_id] = task
                                    ref_manager.add(task)
                            ref_manager.remove(msg.content["task"])
                except queue.Empty:
                    pass
                # A single error message aborts the whole execution.
                try:
                    msg = q_set.q_err.get_nowait()
                    logger.error("raise[task_id={}]".format(
                        msg.content["task_id"]))
                    print(msg.content["trace"])
                    raise Termination("Raised error on task_id={}".format(
                        msg.content["task_id"]) + "trace:\n" + msg.content["trace"])
                except queue.Empty:
                    pass
                # Input queue is full enough: wait before scheduling more.
                if q_set.q_in.qsize() >= buffer:
                    time.sleep(0.2)
                    continue
                # Tasks that stay pending after this scheduling pass.
                next_tasks = OrderedDict()
                for task in tasks.values():
                    if task.task_id in running:
                        continue
                    if is_completed(task, workflow.storage):
                        continue
                    if q_set.q_in.qsize() >= buffer:
                        next_tasks[task.task_id] = task
                        continue
                    inputs = flatten(task.input())
                    # Source tasks of inputs whose output does not exist yet.
                    dependent_tasks_to_execute = OrderedDict()
                    for inp in inputs:
                        if exists_output(inp, workflow.storage):
                            continue
                        dependent_tasks_to_execute[
                            inp.src_task.task_id] = inp.src_task
                    # Case if there is any in-complete task
                    if len(dependent_tasks_to_execute) > 0:
                        for key, value in dependent_tasks_to_execute.items():
                            next_tasks[key] = value
                        next_tasks[task.task_id] = task
                        continue
                    # NOTE(review): a task rejected here is not put into
                    # next_tasks, so it appears to leave the pending set
                    # unless another task re-adds it as a dependency --
                    # confirm this is intended.
                    if not resource_manager.is_runnable(task):
                        continue
                    q_set.q_in.put(
                        Message(kind=Kind.RUN, content={"task": task}))
                    running.append(task.task_id)
                    resource_manager.add(task)
                tasks = next_tasks
                time.sleep(0.2)
        finally:
            # Runs on normal completion and on errors alike.
            time.sleep(1)
            shutdown_all(ws)
    finally:
        manager.shutdown()
def ranch(nprocs: int,
          fn: Callable,
          *args,
          caller_rank: int = 0,
          gather: bool = True,
          ctx: AbstractContextManager = None,
          need: str = "",
          imports="",
          **kwargs):
    r"""
    Execute `fn(\*args, \*\*kwargs)` distributedly in `nprocs` processes. User can serialize over objects and
    functions, spell out import statements, manage execution context, gather results, and the parent process
    can participate as one of the workers.

    If `caller_rank` is `0 <= caller_rank < nprocs`, only `nprocs - 1` processes will be started, and the caller
    process will be a worker to run its share of `fn(..)`.

    If `caller_rank` is ``None``, `nprocs` processes will be started.

    Inside each worker process, its relative rank among all workers is set up in `os.environ['LOCAL_RANK']`,
    and the total number of workers is set up in `os.environ['LOCAL_WORLD_SIZE']`, both as strings.

    Then import statements in `imports`, followed by any objects/functions in `need`, are brought
    into the python global namespace.

    Then, context manager `ctx` is applied around the call `fn(\*args, \*\*kwargs)`.

    Return value of each worker can be gathered in a list (indexed by the process's rank)
    and returned to the caller of `ranch()`.

    Args:
        nprocs: Number of worker processes. Visible as a string in `os.environ['LOCAL_WORLD_SIZE']`
            in all worker processes.

        fn: Function to execute on the worker pool

        \*args: Positional arguments by values to `fn(\*args....)`

        \*\*kwargs: Named parameters to `fn(x=..., y=....)`

        caller_rank: Rank of the parent process. ``0 <= caller_rank < nprocs`` to join, ``None`` to opt out.
            Default to ``0``. In distributed data parallel, 0 means the leading process.

        gather: if ``True``, `ranch` will return a list of return values from each worker, indexed by their ranks.
            If ``False``, and if 'caller_rank' is not None (meaning parent process is a worker),
            `ranch()` will return whatever the parent process' `fn(...)` returns.

        ctx: User defined context manager to be used in a 'with'-clause around the 'fn(...)' call in worker processes.
            Subclassed from AbstractContextManager, ctx needs to define '__enter__()' and '__exit__()' methods.

        need: Space-separated names of objects/functions to be serialized over to the subprocesses. The names
            are looked up in the `__main__` module of the calling process.

        imports: A multiline string of `import` statements to execute in the subprocesses
            before `fn()` execution. Supported formats:

            * `import x, y, z as zoo`
            * `from A import x`
            * `from A import z as zoo`
            * `from A import x, y, z as zoo`
            * Not supported: `from A import (x, y)`

    Returns:
        If `gather` is ``True``: list of results from worker processes, indexed by their `LOCAL_RANK`:
        ``[res_0, res_1, .... res_{nprocs-1}]``. Otherwise the return value of the caller's own `fn(...)`,
        or ``None`` when `caller_rank` is ``None``.

    Raises:
        AssertionError: If `nprocs` or `caller_rank` is out of range.
        KeyError: If a name listed in `need` does not exist in `__main__`.
    """
    assert nprocs > 0, ValueError(
        "nprocs: # of processes to launch must be > 0")
    children_ranks = list(range(nprocs))
    if caller_rank is not None:
        assert 0 <= caller_rank < nprocs, ValueError(
            f"Invalid caller_rank {caller_rank}, must satisfy 0 <= caller_rank < {nprocs}"
        )
        # The caller takes this rank itself, so no child is started for it.
        children_ranks.pop(caller_rank)
    multiproc_ctx, procs = mp.get_context("spawn"), []
    # Shared list the workers write their results into (one slot per rank).
    result_list = multiproc_ctx.Manager().list([None] * nprocs) if gather else None
    try:
        # pass globals in this process to subprocess via fn's wrapper, 'target_fn'
        env = {k: sys.modules['__main__'].__dict__[k] for k in need.split()}
        for rank in children_ranks:
            target_fn = _contextualize(rank,
                                       nprocs,
                                       fn,
                                       cm=ctx,
                                       l=result_list,
                                       env=env,
                                       imports=imports)
            p = multiproc_ctx.Process(target=target_fn,
                                      args=args,
                                      kwargs=kwargs)
            # Only track a process once it has really started: terminating or
            # joining an unstarted process in the ``finally`` block below
            # raises and would hide the original start-up error.
            p.start()
            procs.append(p)
        # The caller runs its own share while the children are working.
        p_res = (_contextualize(
            caller_rank,
            nprocs,
            fn,
            cm=ctx,
            l=result_list,
            env=env,
            imports=imports))(*args, **kwargs) if caller_rank is not None else None
        for p in procs:
            p.join()
        return result_list if gather else p_res
    finally:
        # Make sure no child outlives this call, also on errors.
        for p in procs:
            p.terminate()
            p.join()
def main() -> None:
    """Command-line entry point for pixy.

    Parses the command line, validates the arguments through
    ``pixy.core.check_and_validate_args``, builds the windows/chunks for each
    chromosome, runs ``pixy.core.compute_summary_stats`` on every chunk
    (on a process pool when ``--n_cores`` > 1, inline when it is 1), and then
    splits the temporary results file into one ``<prefix>_<stat>.txt`` file
    per successfully computed statistic.

    Raises:
        Exception: if an interval start exceeds its end, if no position can
            be read from the VCF for a chromosome, if no windows could be
            created, or if the temporary results file is empty.
    """
    # local import: only needed to quote arguments of the tabix shell pipeline
    import shlex

    # argument parsing via argparse

    # the ascii help image
    help_image = "█▀▀█ ░▀░ █░█ █░░█\n" "█░░█ ▀█▀ ▄▀▄ █▄▄█\n" "█▀▀▀ ▀▀▀ ▀░▀ ▄▄▄█\n"

    help_text = 'pixy: unbiased estimates of pi, dxy, and fst from VCFs with invariant sites'
    version = '1.2.5.beta1'
    citation = 'Korunes, KL and K Samuk. pixy: Unbiased estimation of nucleotide diversity and divergence in the presence of missing data. Mol Ecol Resour. 2021 Jan 16. doi: 10.1111/1755-0998.13326.'

    # initialize arguments
    parser = argparse.ArgumentParser(
        description=help_image + help_text + '\n' + version,
        formatter_class=argparse.RawTextHelpFormatter)

    # drop the default groups so that only the custom groups appear in --help
    parser._action_groups.pop()
    required = parser.add_argument_group('required arguments')
    additional = parser.add_argument_group('in addition, one of')
    optional = parser.add_argument_group('optional arguments')

    required.add_argument(
        '--stats',
        nargs='+',
        choices=['pi', 'dxy', 'fst'],
        help=
        'List of statistics to calculate from the VCF, separated by spaces.\ne.g. \"--stats pi fst\" will result in pi and fst calculations.',
        required=True)
    required.add_argument(
        '--vcf',
        type=str,
        nargs='?',
        help='Path to the input VCF (bgzipped and tabix indexed).',
        required=True)
    required.add_argument(
        '--populations',
        type=str,
        nargs='?',
        help=
        'Path to a headerless tab separated populations file with columns [SampleID Population].',
        required=True)

    additional.add_argument(
        '--window_size',
        type=int,
        nargs='?',
        help=
        'Window size in base pairs over which to calculate stats.\nAutomatically determines window coordinates/bounds (see additional options below).',
        required=False)
    additional.add_argument(
        '--bed_file',
        type=str,
        nargs='?',
        help=
        'Path to a headerless .BED file with columns [chrom chromStart chromEnd].\nManually defines window bounds, which can be heterogenous in size.',
        required=False)

    optional.add_argument(
        '--n_cores',
        type=int,
        nargs='?',
        default=1,
        help='Number of CPUs to utilize for parallel processing (default=1).',
        required=False)
    optional.add_argument(
        '--output_folder',
        type=str,
        nargs='?',
        default='',
        help=
        'Folder where output will be written, e.g. path/to/output_folder.\nDefaults to current working directory.',
        required=False)
    optional.add_argument(
        '--output_prefix',
        type=str,
        nargs='?',
        default='pixy',
        help=
        'Optional prefix for output file(s), with no slashes.\ne.g. \"--output_prefix output\" will result in [output folder]/output_pi.txt. \nDefaults to \'pixy\'.',
        required=False)
    optional.add_argument(
        '--chromosomes',
        type=str,
        nargs='?',
        default='all',
        help=
        'A single-quoted, comma separated list of chromosomes where stats will be calculated. \ne.g. --chromosomes \'X,1,2\' will restrict stats to chromosomes X, 1, and 2.\nDefaults to all chromosomes in the VCF.',
        required=False)
    optional.add_argument(
        '--interval_start',
        type=str,
        nargs='?',
        help=
        'The start of an interval over which to calculate stats.\nOnly valid when calculating over a single chromosome.\nDefaults to 1.',
        required=False)
    optional.add_argument(
        '--interval_end',
        type=str,
        nargs='?',
        help=
        'The end of the interval over which to calculate stats.\nOnly valid when calculating over a single chromosome.\nDefaults to the last position for a chromosome.',
        required=False)
    optional.add_argument(
        '--sites_file',
        type=str,
        nargs='?',
        help=
        'Path to a headerless tab separated file with columns [CHROM POS].\nThis defines sites where summary stats should be calculated.\nCan be combined with the --bed_file and --window_size options.',
        required=False)
    optional.add_argument(
        '--chunk_size',
        type=int,
        nargs='?',
        default=100000,
        help=
        'Approximate number of sites to read from VCF at any given time (default=100000).\nLarger numbers reduce I/O operations at the cost of memory.',
        required=False)
    optional.add_argument(
        '--fst_type',
        choices=['wc', 'hudson'],
        default='wc',
        help=
        'FST estimator to use, one of either: \n\'wc\' (Weir and Cockerham 1984) or\n\'hudson\' (Hudson 1992, Bhatia et al. 2013) \nDefaults to \'wc\'.',
        required=False)
    optional.add_argument(
        '--bypass_invariant_check',
        choices=['yes', 'no'],
        default='no',
        help=
        'Allow computation of stats without invariant sites (default=no).\nWill result in wildly incorrect estimates most of the time.\nUse with extreme caution!',
        required=False)
    optional.add_argument('--version',
                          action='version',
                          version=help_image + 'version ' + version,
                          help='Print the version of pixy in use.')
    optional.add_argument('--citation',
                          action='version',
                          version=citation,
                          help='Print the citation for pixy.')
    optional.add_argument('--silent',
                          action='store_true',
                          help='Suppress all console output.')
    optional.add_argument('--debug',
                          action='store_true',
                          help=argparse.SUPPRESS)
    optional.add_argument('--keep_temp_file',
                          action='store_true',
                          help=argparse.SUPPRESS)

    # catch arguments from the command line
    # automatically uncommented when a release is built
    args = parser.parse_args()

    # if not running in debug mode, suppress traceback
    if not args.debug:
        sys.tracebacklimit = 0

    # if running in silent mode, suppress output
    # (keep the handle so it can be closed when stdout is restored)
    devnull = None
    if args.silent:
        devnull = open(os.devnull, "w")
        sys.stdout = devnull

    # validate arguments with the check_and_validate_args function
    # returns parsed population, chromosome, and sample info
    print("[pixy] pixy " + version)
    print("[pixy] See documentation at https://pixy.readthedocs.io/en/latest/")
    popnames, popindices, chrom_list, IDs, temp_file, output_folder, output_prefix, bed_df, sites_df = pixy.core.check_and_validate_args(
        args)

    print("\n[pixy] Preparing for calculation of summary statistics: " +
          ', '.join(map(str, args.stats)))

    if 'fst' in args.stats:
        if args.fst_type == 'wc':
            fst_cite = 'Weir and Cockerham (1984)'
        elif args.fst_type == 'hudson':
            fst_cite = 'Hudson (1992)'
        print("[pixy] Using " + fst_cite + "\'s estimator of FST.")

    print("[pixy] Data set contains " + str(len(popnames)) +
          " population(s), " + str(len(chrom_list)) + " chromosome(s), and " +
          str(len(IDs)) + " sample(s)")

    if args.window_size is not None:
        print("[pixy] Window size: " + str(args.window_size) + " bp")

    if args.bed_file is not None:
        print("[pixy] Windows sourced from: " + args.bed_file)

    if args.sites_file is not None:
        print("[pixy] Calculations restricted to sites in: " + args.sites_file)

    print('')

    # time the calculations
    start_time = time.time()
    print("[pixy] Started calculations at " +
          time.strftime("%H:%M:%S on %Y-%m-%d", time.localtime(start_time)))
    print("[pixy] Using " + str(args.n_cores) + " out of " +
          str(mp.cpu_count()) + " available CPU cores\n")

    # if in mc mode, set up multiprocessing
    if (args.n_cores > 1):

        # use forking context on linux, and spawn otherwise (macOS)
        if sys.platform == "linux":
            ctx = mp.get_context("fork")
        else:
            ctx = mp.get_context("spawn")

        # set up the multiprocessing manager, queue, and process pool
        manager = ctx.Manager()
        q = manager.Queue()
        pool = ctx.Pool(int(args.n_cores))

        # a listener function for writing a temp file
        # used to write output in multicore mode
        def listener(q, temp_file):
            with open(temp_file, 'a') as f:
                while 1:
                    m = q.get()
                    # 'kill' is the sentinel that stops the listener
                    if m == 'kill':
                        break
                    f.write(str(m) + '\n')
                    f.flush()

        # launch the watcher function for collecting output
        watcher = pool.apply_async(listener, args=(
            q,
            temp_file,
        ))

    # begin processing each chromosome
    for chromosome in chrom_list:

        print("[pixy] Processing chromosome/contig " + chromosome + "...")

        # if not using a bed file, build windows manually
        if args.bed_file is None:

            # if an interval is specified, assign it
            if args.interval_start is not None:
                interval_start = int(args.interval_start)
                interval_end = int(args.interval_end)

            # otherwise, get the interval from the VCF's POS column
            else:
                if args.sites_file is None:
                    # The pipeline runs through a shell, so the user-supplied
                    # path and contig name are quoted to survive spaces and
                    # shell metacharacters.
                    chrom_max = subprocess.check_output(
                        "tabix " + shlex.quote(args.vcf) + " " +
                        shlex.quote(chromosome) + " | cut -f 2 | tail -n 1",
                        shell=True).decode("utf-8").split()
                    if not chrom_max:
                        raise Exception(
                            "[pixy] ERROR: Could not read any positions for chromosome/contig "
                            + chromosome + " from the VCF " + args.vcf)
                    interval_start = 1
                    interval_end = int(chrom_max[0])
                else:
                    sites_df = pandas.read_csv(args.sites_file,
                                               sep='\t',
                                               usecols=[0, 1],
                                               names=['CHROM', 'POS'])
                    sites_pre_list = sites_df[sites_df['CHROM'] == chromosome]
                    sites_pre_list = sorted(sites_pre_list['POS'].tolist())
                    interval_start = min(sites_pre_list)
                    interval_end = max(sites_pre_list)

            # final check if intervals are valid
            if interval_start > interval_end:
                raise Exception(
                    "[pixy] ERROR: The specified interval start (" +
                    str(interval_start) + ") exceeds the interval end (" +
                    str(interval_end) + ")")

            targ_region = chromosome + ":" + str(interval_start) + "-" + str(
                interval_end)

            print("[pixy] Calculating statistics for region " + targ_region +
                  "...")

            # Determine list of windows over which to compute stats

            # window size
            window_size = args.window_size

            # in the case where window size = 1, AND there is a sites file, use the sites file as the 'windows'
            # NOTE(review): sites_pre_list is only assigned above when no
            # explicit interval was given -- confirm that combination
            # (interval + sites file + window size 1) is rejected upstream.
            if window_size == 1 and args.sites_file is not None:
                window_list = [
                    list(a) for a in zip(sites_pre_list, sites_pre_list)
                ]
            else:
                # if the interval is smaller than one window, make a list of length 1
                if (interval_end - interval_start) <= window_size:
                    window_pos_1_list = [interval_start]
                    window_pos_2_list = [interval_start + window_size - 1]
                else:
                    # if the interval_end is not a perfect multiple of the window size
                    # bump the interval_end up to the nearest multiple of window size
                    if not (interval_end % window_size == 0):
                        interval_end = interval_end + (
                            window_size - (interval_end % window_size))

                    # create the start and stops for each window
                    window_pos_1_list = [
                        *range(interval_start, int(interval_end), window_size)
                    ]
                    window_pos_2_list = [
                        *range(interval_start +
                               (window_size - 1),
                               int(interval_end) + window_size, window_size)
                    ]

                window_list = [
                    list(a)
                    for a in zip(window_pos_1_list, window_pos_2_list)
                ]

            # Set aggregate to true if 1) the window size is larger than the chunk size OR 2) the window size wasn't specified, but the chrom is longer than the cutoff
            if (window_size > args.chunk_size) or (
                (args.window_size is None) and
                ((interval_end - interval_start) > args.chunk_size)):
                aggregate = True
            else:
                aggregate = False

        # if using a bed file, subset the bed file for the current chromosome
        else:
            aggregate = False
            bed_df_chrom = bed_df[bed_df['chrom'] == chromosome]
            window_list = [
                list(a) for a in zip(bed_df_chrom['chromStart'],
                                     bed_df_chrom['chromEnd'])
            ]

        if (len(window_list) == 0):
            raise Exception(
                "[pixy] ERROR: Window creation failed. Ensure that the POS column in the VCF is valid or change --window_size."
            )

        # if aggregating, break down large windows into smaller windows
        if aggregate:
            window_list = pixy.core.assign_subwindows_to_windows(
                window_list, args.chunk_size)

        # using chunk_size, assign windows to chunks
        window_list = pixy.core.assign_windows_to_chunks(
            window_list, args.chunk_size)

        # if using a sites file, assign sites to chunks, as with windows above
        if args.sites_file is not None:
            sites_df = pandas.read_csv(args.sites_file,
                                       sep='\t',
                                       usecols=[0, 1],
                                       names=['CHROM', 'POS'])
            sites_pre_list = sites_df[sites_df['CHROM'] == chromosome]
            sites_pre_list = sites_pre_list['POS'].tolist()
            sites_list = pixy.core.assign_sites_to_chunks(
                sites_pre_list, args.chunk_size)
        else:
            sites_df = None

        # obtain the list of chunks from the window list
        # (the chunk id is read from index 2 of each window entry)
        chunk_list = [i[2] for i in window_list]
        chunk_list = list(set(chunk_list))

        # if running in mc mode, send the summary stats function to the jobs pool
        if (args.n_cores > 1):

            # the list of jobs to be launched
            jobs = []

            for chunk in chunk_list:

                # create a subset of the window list specific to this chunk
                window_list_chunk = [x for x in window_list if x[2] == chunk]

                # and for the site list (if it exists)
                if sites_df is not None:
                    sites_list_chunk = [x for x in sites_list if x[1] == chunk]
                    sites_list_chunk = [x[0] for x in sites_list_chunk]
                else:
                    sites_list_chunk = None

                # determine the bounds of the chunk
                chunk_pos_1 = min(window_list_chunk, key=lambda x: x[1])[0]
                chunk_pos_2 = max(window_list_chunk, key=lambda x: x[1])[1]

                # launch a summary stats job for this chunk
                job = pool.apply_async(
                    pixy.core.compute_summary_stats,
                    args=(args, popnames, popindices, temp_file, chromosome,
                          chunk_pos_1, chunk_pos_2, window_list_chunk, q,
                          sites_list_chunk, aggregate, args.window_size))
                jobs.append(job)

            # wait for all the jobs of this chromosome to finish
            # (job.get() also re-raises any exception from a worker)
            for job in jobs:
                job.get()

        # if running in single core mode, loop over the function manually
        elif (args.n_cores == 1):

            for chunk in chunk_list:

                # create a subset of the window list specific to this chunk
                window_list_chunk = [x for x in window_list if x[2] == chunk]

                # and for the site list (if it exists)
                if sites_df is not None:
                    sites_list_chunk = [x for x in sites_list if x[1] == chunk]
                    sites_list_chunk = [x[0] for x in sites_list_chunk]
                else:
                    sites_list_chunk = None

                # determine the bounds of the chunk
                chunk_pos_1 = min(window_list_chunk, key=lambda x: x[1])[0]
                chunk_pos_2 = max(window_list_chunk, key=lambda x: x[1])[1]

                # don't use the queue (q) when running in single core mode
                q = "NULL"

                # compute summary stats for all windows in the chunk window list
                pixy.core.compute_summary_stats(args, popnames, popindices,
                                                temp_file, chromosome,
                                                chunk_pos_1, chunk_pos_2,
                                                window_list_chunk, q,
                                                sites_list_chunk, aggregate,
                                                args.window_size)

    # clean up any remaining jobs and stop the listener
    if (args.n_cores > 1):
        q.put('kill')
        pool.close()
        pool.join()

    # split and aggregate temp file to individual files

    # check if there is any output to process
    # halt execution if not
    try:
        outpanel = pandas.read_csv(temp_file, sep='\t', header=None)
    except pandas.errors.EmptyDataError:
        raise Exception(
            '[pixy] ERROR: pixy failed to write any output. Confirm that your bed/sites files and intervals refer to existing chromosomes and positions in the VCF.'
        )

    # check if particular stats failed to generate output
    successful_stats = np.unique(outpanel[0])

    # if not all requested stats were generated, produce a warning
    # and then remove the failed stats from the args list
    if set(args.stats) != set(successful_stats):
        missing_stats = list(set(args.stats) - set(successful_stats))
        print(
            '\n[pixy] WARNING: pixy failed to find any valid gentoype data to calculate the following summary statistics: '
            + ', '.join(missing_stats) +
            ". No output file will be created for these statistics.")
        args.stats = set(successful_stats)

    outpanel[3] = outpanel[3].astype(str)  # force chromosome IDs to string
    outgrouped = outpanel.groupby([0, 3])  # groupby statistic, chromosome

    # enforce chromosome IDs as strings
    chrom_list = list(map(str, chrom_list))

    # NOTE: `aggregate` and `window_size` below hold the values left over
    # from the last chromosome processed in the loop above.

    if 'pi' in args.stats:
        stat = 'pi'
        pi_file = str(output_prefix) + "_pi.txt"

        if os.path.exists(pi_file):
            os.remove(pi_file)

        # the with-block guarantees the file is closed even if a group lookup
        # or write fails part-way through
        with open(pi_file, 'a') as outfile:
            outfile.write("pop" + "\t" + "chromosome" + "\t" +
                          "window_pos_1" + "\t" + "window_pos_2" + "\t" +
                          "avg_pi" + "\t" + "no_sites" + "\t" +
                          "count_diffs" + "\t" + "count_comparisons" + "\t" +
                          "count_missing" + "\n")

            if aggregate:  # put winsizes back together for each population to make final_window_size

                for chromosome in chrom_list:
                    outpi = outgrouped.get_group(
                        ("pi", chromosome)).reset_index(
                            drop=True)  # get this statistic, this chrom only
                    outpi.drop(
                        [0, 2], axis=1, inplace=True
                    )  # get rid of "pi" and placeholder (NA) columns
                    outsorted = pixy.core.aggregate_output(
                        outpi, stat, chromosome, window_size, args.fst_type)
                    outsorted.to_csv(outfile,
                                     sep="\t",
                                     mode='a',
                                     header=False,
                                     index=False,
                                     na_rep='NA')  # write

            else:
                for chromosome in chrom_list:
                    outpi = outgrouped.get_group(
                        ("pi", chromosome)).reset_index(
                            drop=True)  # get this statistic, this chrom only
                    outpi.drop(
                        [0, 2], axis=1, inplace=True
                    )  # get rid of "pi" and placeholder (NA) columns
                    outsorted = outpi.sort_values([4])  # sort by position

                    # make sure sites, comparisons, missing get written as integers
                    cols = [7, 8, 9, 10]
                    outsorted[cols] = outsorted[cols].astype('Int64')
                    outsorted.to_csv(outfile,
                                     sep="\t",
                                     mode='a',
                                     header=False,
                                     index=False,
                                     na_rep='NA')  # write

    if 'dxy' in args.stats:
        stat = 'dxy'
        dxy_file = str(output_prefix) + "_dxy.txt"

        if os.path.exists(dxy_file):
            os.remove(dxy_file)

        with open(dxy_file, 'a') as outfile:
            outfile.write("pop1" + "\t" + "pop2" + "\t" + "chromosome" +
                          "\t" + "window_pos_1" + "\t" + "window_pos_2" +
                          "\t" + "avg_dxy" + "\t" + "no_sites" + "\t" +
                          "count_diffs" + "\t" + "count_comparisons" + "\t" +
                          "count_missing" + "\n")

            if aggregate:  # put winsizes back together for each population to make final_window_size

                for chromosome in chrom_list:
                    outdxy = outgrouped.get_group(
                        ("dxy", chromosome)).reset_index(
                            drop=True)  # get this statistic, this chrom only
                    outdxy.drop([0], axis=1,
                                inplace=True)  # get rid of "dxy"
                    outsorted = pixy.core.aggregate_output(
                        outdxy, stat, chromosome, window_size, args.fst_type)
                    outsorted.to_csv(outfile,
                                     sep="\t",
                                     mode='a',
                                     header=False,
                                     index=False,
                                     na_rep='NA')  # write

            else:
                for chromosome in chrom_list:
                    outdxy = outgrouped.get_group(
                        ("dxy", chromosome)).reset_index(
                            drop=True)  # get this statistic, this chrom only
                    outdxy.drop([0], axis=1,
                                inplace=True)  # get rid of "dxy"
                    outsorted = outdxy.sort_values([4])  # sort by position

                    # make sure sites, comparisons, missing get written as integers
                    cols = [7, 8, 9, 10]
                    outsorted[cols] = outsorted[cols].astype('Int64')
                    outsorted.to_csv(outfile,
                                     sep="\t",
                                     mode='a',
                                     header=False,
                                     index=False,
                                     na_rep='NA')  # write

    if 'fst' in args.stats:
        stat = 'fst'
        fst_file = str(output_prefix) + "_fst.txt"

        if os.path.exists(fst_file):
            os.remove(fst_file)

        # keep track of chromosomes with no fst data
        chroms_with_no_data = []

        with open(fst_file, 'a') as outfile:
            outfile.write("pop1" + "\t" + "pop2" + "\t" + "chromosome" +
                          "\t" + "window_pos_1" + "\t" + "window_pos_2" +
                          "\t" + "avg_" + args.fst_type + "_fst" + "\t" +
                          "no_snps" + "\n")

            if aggregate:  # put winsizes back together for each population to make final_window_size

                for chromosome in chrom_list:

                    # logic to accommodate cases where pi/dxy have stats for a chromosome, but fst does not
                    chromosome_has_data = True

                    # if there are no valid fst estimates, set chromosome_has_data = False
                    try:
                        outfst = outgrouped.get_group(
                            ("fst", chromosome)).reset_index(
                                drop=True
                            )  # get this statistic, this chrom only
                    except KeyError:
                        chroms_with_no_data.append(chromosome)
                        chromosome_has_data = False

                    if chromosome_has_data:
                        outfst.drop([0], axis=1,
                                    inplace=True)  # get rid of "fst"
                        outsorted = pixy.core.aggregate_output(
                            outfst, stat, chromosome, window_size,
                            args.fst_type)
                        outsorted = outsorted.iloc[:, :
                                                   -3]  # drop components (for now)
                        outsorted.to_csv(outfile,
                                         sep="\t",
                                         mode='a',
                                         header=False,
                                         index=False,
                                         na_rep='NA')  # write

            else:
                for chromosome in chrom_list:

                    # logic to accommodate cases where pi/dxy have stats for a chromosome, but fst does not
                    chromosome_has_data = True

                    # if there are no valid fst estimates, set chromosome_has_data = False
                    try:
                        outfst = outgrouped.get_group(
                            ("fst", chromosome)).reset_index(
                                drop=True
                            )  # get this statistic, this chrom only
                    except KeyError:
                        chroms_with_no_data.append(chromosome)
                        chromosome_has_data = False

                    if chromosome_has_data:
                        outfst.drop([0], axis=1,
                                    inplace=True)  # get rid of "fst"
                        outsorted = outfst.sort_values(
                            [4])  # sort by position

                        # make sure sites (but not components like pi/dxy)
                        # get written as integers
                        cols = [7]
                        outsorted[cols] = outsorted[cols].astype('Int64')
                        outsorted = outsorted.iloc[:, :
                                                   -3]  # drop components (for now)
                        outsorted.to_csv(outfile,
                                         sep="\t",
                                         mode='a',
                                         header=False,
                                         index=False,
                                         na_rep='NA')

        if len(chroms_with_no_data) >= 1:
            print(
                '\n[pixy] NOTE: ' +
                'The following chromosomes/scaffolds did not have sufficient data to estimate FST: '
                + ', '.join(chroms_with_no_data))

    # remove the temp file(s)
    if (args.keep_temp_file is not True):
        os.remove(temp_file)

    # confirm output was generated successfully
    outfolder_files = [
        f for f in os.listdir(output_folder)
        if os.path.isfile(os.path.join(output_folder, f))
    ]

    r = re.compile(".*_dxy.*|.*_pi.*|.*_fst.*")
    output_files = list(filter(r.match, outfolder_files))

    r = re.compile("pixy_tmpfile.*")
    leftover_tmp_files = list(filter(r.match, outfolder_files))

    if len(output_files) == 0:
        print(
            '\n[pixy] WARNING: pixy failed to write any output files. Your VCF may not contain valid genotype data, or it was removed via filtering using the specified sites/bed file (if any).'
        )

    # print completion message
    end_time = time.time()
    print("\n[pixy] All calculations complete at " +
          time.strftime("%H:%M:%S on %Y-%m-%d", time.localtime(end_time)))
    total_time = (time.time() - start_time)
    print("[pixy] Time elapsed: " +
          time.strftime("%H:%M:%S", time.gmtime(total_time)))
    print("[pixy] Output files written to: " + output_folder)

    if len(leftover_tmp_files) > 0:
        print("\n[pixy] NOTE: There are pixy temp files in " + output_folder)
        print(
            "[pixy] If these are not actively being used (e.g. by another running pixy process), they can be safely deleted."
        )

    print(
        "\n[pixy] If you use pixy in your research, please cite the following paper:\n[pixy] "
        + citation)

    # restore output and release the /dev/null handle opened for silent mode
    if args.silent:
        sys.stdout = sys.__stdout__
        if devnull is not None:
            devnull.close()