Example #1
 def __init__(self, mp_context=None):
     super(ProcessExecutor, self).__init__()
     if not mp_context:
         # noinspection PyUnresolvedReferences
         from multiprocess import get_context
         mp_context = get_context()
     self._ctx = mp_context
Example #2
 def ctx(self, starter):
     self._ctx = get_context(starter)
     self._ctxname = starter
     registered_objects = self._registry[starter]
     for obj, name in registered_objects:
         if name is None:
             name = obj.__name__
         setattr(self._ctx, name, obj)
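
A minimal self-contained sketch (not the original class) of the register-then-attach pattern this setter relies on: objects recorded per start method are attached as attributes of the freshly created context.

from collections import defaultdict
from multiprocessing import get_context

class ContextRegistry:
    """Hypothetical holder for objects registered per start method."""

    def __init__(self):
        self._registry = defaultdict(list)  # starter -> [(obj, name), ...]
        self._ctx = None
        self._ctxname = None

    def register(self, obj, starter, name=None):
        # Remember an object to expose on the context created for this start method.
        self._registry[starter].append((obj, name))

    def set_ctx(self, starter):
        # Build the context, then attach every registered object to it.
        self._ctx = get_context(starter)
        self._ctxname = starter
        for obj, name in self._registry[starter]:
            setattr(self._ctx, name or obj.__name__, obj)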
Example #3
    def __init__(self, device, ctx=mp.get_context("spawn")):
        """ Constructor method
        """

        super().__init__(device, ctx)
        self.display_pose = None
        self.display_pose_queue = ClearableMPQueue(2, ctx=self.ctx)
        self.pose_process = None
Example #4
def test_spawned_sharedmem():
    array = SpawnSharedmemManager.shared_ndarray((10, 3), 'd')
    mm = SpawnSharedmemManager(array)
    ctx = mp.get_context('spawn')

    with ctx.Pool(processes=max(1, mp.cpu_count() - 3), initializer=init_shared_array, initargs=(mm,)) as p:
        p.map(write_to_array, range(len(array)))
        p.close()
        p.join()

    expected_result = np.repeat(np.arange(array.shape[0]), array.shape[1]).reshape(array.shape)
    assert np.all(expected_result == array), 'Spawned shared-memory writing failed'
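
The init_shared_array and write_to_array helpers used by this test are not shown. A hypothetical sketch of what they could look like, assuming SpawnSharedmemManager exposes the attached array inside each worker (the get_array accessor below is invented):

_shared = None  # module-level view, set in each spawned worker by the pool initializer

def init_shared_array(mm):
    # Pool initializer: re-attach to the shared block inside the worker process.
    global _shared
    _shared = mm.get_array()  # hypothetical accessor on SpawnSharedmemManager

def write_to_array(i):
    # Fill row i with its own index, matching the expected_result pattern checked above.
    _shared[i, :] = i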
Example #5
def ranch(nprocs: int,
          fn: Callable,
          *args,
          parent_rank: int = 0,
          host_rank: int = 0,
          ctx=None,
          **kwargs):
    '''Launch `fn(*args, **kwargs)` in `nprocs` spawned processes. Local rank, global rank (across multiple
       hosts), and world size are set in os.environ as 'LOCAL_RANK', 'RANK', and 'WORLD_SIZE' respectively.
       The parent process can participate as rank `parent_rank`.
       A context manager `ctx` can optionally be applied around `fn(...)`.'''
    assert nprocs > 0, ValueError(
        "nprocs: # of processes to launch must be > 0")
    children_ranks = list(range(nprocs))
    if parent_rank is not None:
        assert 0 <= parent_rank < nprocs, ValueError(
            f"Out of range parent_rank:{parent_rank}, must be 0 <= parent_rank < {nprocs}"
        )
        children_ranks.pop(parent_rank)

    multiproc_ctx = mp.get_context("forkserver")

    procs = []
    try:
        os.environ["WORLD_SIZE"], base_rank = str(nprocs), host_rank * nprocs
        target_fn = _contextualize(fn, ctx) if ctx else fn

        for rank in children_ranks:
            os.environ.update({
                "LOCAL_RANK": str(rank),
                "RANK": str(rank + base_rank)
            })
            p = multiproc_ctx.Process(target=target_fn,
                                      args=args,
                                      kwargs=kwargs)
            procs.append(p)
            p.start()

        if parent_rank is not None:  # also run it in current process at a rank
            os.environ.update({
                "LOCAL_RANK": str(parent_rank),
                "RANK": str(parent_rank + base_rank)
            })
            return target_fn(*args, **kwargs)
        else:
            return procs
    finally:
        for k in ["WORLD_SIZE", "RANK", "LOCAL_RANK"]:
            os.environ.pop(k, None)
        for p in procs:
            p.join()
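
The _contextualize helper referenced above is not shown. A plausible sketch, assuming it only wraps fn so the supplied context manager is entered around each call:

def _contextualize(fn, cm):
    # Hypothetical sketch: run fn(...) inside the given context manager.
    def wrapper(*args, **kwargs):
        with cm:
            return fn(*args, **kwargs)
    return wrapper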
Example #6
    def start_processing_pool():
        if config.multicore_processing:
            if config.max_core_usage == 'max':
                workers = mp.cpu_count()
            else:
                workers = config.max_core_usage

            ctx = mp.get_context(
                'forkserver')  # this avoids a segmentation fault on macOS
            if config.verbose:
                print(f"\nStarting processing pool using {workers} cores.")
            return ctx.Pool(processes=workers)
Example #7
    def __init__(self, max_workers=None, mp_context=None, initializer=None,
                 initargs=()):
        super(ProcessPoolExecutor, self).__init__()
        if not mp_context:
            # noinspection PyUnresolvedReferences
            from multiprocess import get_context
            mp_context = get_context()

        self._max_workers = max_workers
        self._ctx = mp_context
        self._initializer = initializer
        self._initargs = initargs
        self.lock = self._ctx.Lock()
        self.pool = None
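
A minimal usage sketch, assuming the executor class above is importable under this name; passing a ready-made context pins the pool to a specific start method:

from multiprocess import get_context

# Hypothetical usage: force the 'spawn' start method instead of the platform default.
executor = ProcessPoolExecutor(max_workers=4, mp_context=get_context("spawn"))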
Example #8
    def __init__(self, device, ctx=mp.get_context("spawn")):
        """ Constructor method
        """

        self.device = device
        self.ctx = ctx

        res = self.device.im_size
        self.frame_shared = mp.Array(ctypes.c_uint8, res[1] * res[0] * 3)
        self.frame = np.frombuffer(self.frame_shared.get_obj(),
                                   dtype="uint8").reshape(res[1], res[0], 3)
        self.frame_time_shared = mp.Array(ctypes.c_double, 1)
        self.frame_time = np.frombuffer(self.frame_time_shared.get_obj(),
                                        dtype="d")

        self.q_to_process = ClearableMPQueue(ctx=self.ctx)
        self.q_from_process = ClearableMPQueue(ctx=self.ctx)
        self.write_frame_queue = ClearableMPQueue(ctx=self.ctx)

        self.capture_process = None
        self.writer_process = None
Example #9
def estimate_param_scan(estimator, X, param_sets, evaluate=None, evaluate_args=None, failfast=True,
                        return_estimators=False, n_jobs=1, progress_reporter=None, show_progress=True,
                        return_exceptions=False):
    """ Runs multiple estimations using a list of parameter settings

    Parameters
    ----------
    estimator : Estimator object or class
        An estimator object that provides an estimate(X, **params) function.
        If only a class is provided here, the Estimator objects will be
        constructed with default parameter settings, and the parameter settings
        from param_sets for each estimation. If you want to specify other
        parameter settings for those parameters not specified in param_sets,
        construct an Estimator before and pass the object.

    param_sets : iterable over dictionaries
        An iterable that provides parameter settings. Each element defines a
        parameter set, for which an estimation will be run using these
        parameters in estimate(X, **params). All other parameter settings will
        be taken from the default settings in the estimator object.

    evaluate : str or list of str, optional
        The given methods or properties will be called on the estimated
        models, and their results will be returned instead of the full models.
        This may be useful for reducing memory overhead.

    evaluate_args: iterable of iterable, optional
        Arguments to be passed to the evaluated methods. Note that its size has to match the size of evaluate.

    failfast : bool
        If True, will raise an exception when an estimation fails with an exception
        or a requested method doesn't exist. If False, will simply
        return None in these cases.

    return_estimators: bool
        If True, return a list of estimators in addition to the models.

    show_progress: bool
        If the given estimator supports the show_progress interface, the flag is set
        prior to running the estimations.

    return_exceptions: bool, default=False
        If failfast is False and this setting is True, the exception raised at the corresponding grid element
        is returned instead of None.

    Returns
    -------
    models : list of model objects or evaluated function values
        A list of estimated models in the same order as param_sets. If evaluate
        is given, each element will contain the results from these method
        evaluations.

    estimators (optional) : list of estimator objects. These are returned only
        if return_estimators=True

    Examples
    --------

    Estimate a maximum likelihood Markov model at lag times 1, 2, 3.

    >>> from pyemma.msm.estimators import MaximumLikelihoodMSM, BayesianMSM
    >>>
    >>> dtraj = [0,0,1,2,1,0,1,0,1,2,2,0,0,0,1,1,2,1,0,0,1,2,1,0,0,0,1,1,0,1,2]  # mini-trajectory
    >>> param_sets=param_grid({'lag': [1,2,3]})
    >>>
    >>> estimate_param_scan(MaximumLikelihoodMSM, dtraj, param_sets, evaluate='timescales')
    [array([ 1.24113168,  0.77454377]), array([ 2.65266698,  1.42909842]), array([ 5.34810405,  1.14784446])]

    Now we also try to get samples of the timescales; the maximum likelihood estimator cannot provide these.
    >>> estimate_param_scan(MaximumLikelihoodMSM, dtraj, param_sets, failfast=False,
    ...     evaluate=['timescales', 'timescales_samples']) # doctest: +SKIP
    [[array([ 1.24113168,  0.77454377]), None], [array([ 2.48226337,  1.54908754]), None], [array([ 3.72339505,  2.32363131]), None]]

    We get Nones because the MaximumLikelihoodMSM estimator doesn't provide timescales_samples. Use for example
    a Bayesian estimator for that.

    Now we also want to get samples of the timescales using the BayesianMSM.
    >>> estimate_param_scan(BayesianMSM, dtraj, param_sets, show_progress=False,
    ...     evaluate=['timescales', 'sample_f'], evaluate_args=((), ('timescales', ))) # doctest: +SKIP
    [[array([ 1.24357685,  0.77609028]), [array([ 1.5963252 ,  0.73877883]), array([ 1.29915847,  0.49004912]), array([ 0.90058583,  0.73841786]), ... ]]

    """
    # make sure we have an estimator object
    estimator = get_estimator(estimator)
    if hasattr(estimator, 'show_progress'):
        estimator.show_progress = show_progress

    if n_jobs is None:
        from pyemma._base.parallel import get_n_jobs
        n_jobs = get_n_jobs(logger=getattr(estimator, 'logger', None))

    # if we want to return estimators, make clones. Otherwise just copy references.
    # For parallel processing we always need clones.
    # Also if the Estimator is its own Model, we have to clone.
    from pyemma._base.model import Model
    if (return_estimators or
        n_jobs > 1 or n_jobs is None or
        isinstance(estimator, Model)):
        estimators = [clone_estimator(estimator) for _ in param_sets]
    else:
        estimators = [estimator for _ in param_sets]

    # only show progress of parameter study.
    if hasattr(estimators[0], 'show_progress'):
        for e in estimators:
            e.show_progress = False

    # if we evaluate, make sure we have a list of functions to evaluate
    if _types.is_string(evaluate):
        evaluate = [evaluate]
    if _types.is_string(evaluate_args):
        evaluate_args = [evaluate_args]

    if evaluate is not None and evaluate_args is not None and len(evaluate) != len(evaluate_args):
        raise ValueError("length mismatch: evaluate ({}) and evaluate_args ({})".format(len(evaluate), len(evaluate_args)))

    logger_available = hasattr(estimators[0], 'logger')
    if logger_available:
        logger = estimators[0].logger
    if progress_reporter is None:
        from unittest.mock import MagicMock
        ctx = progress_reporter = MagicMock()
        callback = None
    else:
        ctx = progress_reporter._progress_context('param-scan')
        callback = lambda _: progress_reporter._progress_update(1, stage='param-scan')

        progress_reporter._progress_register(len(estimators), stage='param-scan',
                                             description="estimating %s" % str(estimator.__class__.__name__))

    # TODO: test on win, osx
    if n_jobs > 1 and os.name == 'posix':
        if logger_available:
            logger.debug('estimating %s with n_jobs=%s', estimator, n_jobs)
        # iterate over parameter settings
        limit_threads = True
        task_iter = ((estimator,
                      param_set, X,
                      evaluate,
                      evaluate_args,
                      failfast,
                      return_exceptions,
                      limit_threads)
                     for estimator, param_set in zip(estimators, param_sets))

        if system() == "Linux":
            pool = get_context("spawn").Pool(processes=n_jobs)
        else:
            pool = get_context().Pool(processes=n_jobs)
        args = list(task_iter)

        from contextlib import closing

        def error_callback(*args, **kw):
            if failfast:
                # TODO: can we be specific here? eg. obtain the stack of the actual process or is this the master proc?
                raise Exception('something failed')

        with closing(pool), ctx:
            res_async = [pool.apply_async(_estimate_param_scan_worker, a, callback=callback,
                                          error_callback=error_callback) for a in args]
            res = [x.get() for x in res_async]

    # if n_jobs=1 don't invoke the pool, but directly dispatch the iterator
    else:
        if logger_available:
            logger.debug('estimating %s with n_jobs=1 because of the setting or '
                         'because this is not a POSIX system', estimator)
        res = []
        with ctx:
            for estimator, param_set in zip(estimators, param_sets):
                res.append(_estimate_param_scan_worker(estimator, param_set, X,
                                                       evaluate, evaluate_args, failfast, return_exceptions, False))
                if progress_reporter is not None:
                    progress_reporter._progress_update(1, stage='param-scan')

    # done
    if return_estimators:
        return res, estimators
    else:
        return res
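
_estimate_param_scan_worker is dispatched to the pool but not shown here. A hypothetical sketch of the shape such a worker could take, based only on the docstring's estimate(X, **params) contract (the real pyemma implementation differs in its details):

def _estimate_param_scan_worker(estimator, params, X, evaluate, evaluate_args,
                                failfast, return_exceptions, limit_threads):
    # limit_threads is accepted for signature compatibility and ignored in this sketch.
    try:
        model = estimator.estimate(X, **params)
    except Exception as e:
        if failfast:
            raise
        return e if return_exceptions else None
    if evaluate is None:
        return model
    # Evaluate the requested methods/properties on the model instead of returning it.
    values = []
    for i, name in enumerate(evaluate):
        attr = getattr(model, name, None)
        if attr is None:
            if failfast:
                raise AttributeError(name)
            values.append(None)
            continue
        args = evaluate_args[i] if evaluate_args is not None else ()
        values.append(attr(*args) if callable(attr) else attr)
    return values if len(values) > 1 else values[0]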
Example #10
import multiprocess as mp
import threading


# Threading context so we can use threading as backend as well
class ThreadingContext:

    Event = threading.Event
    Lock = threading.Lock
    Thread = threading.Thread

    # threading doesn't have Array and JoinableQueue, so we take it from multiprocessing. Both are thread-safe. We need
    # the Process class for the MPIRE insights SyncManager instance.
    Array = mp.Array
    JoinableQueue = mp.JoinableQueue
    Process = mp.Process


MP_CONTEXTS = {
    'fork': mp.get_context('fork'),
    'forkserver': mp.get_context('forkserver'),
    'spawn': mp.get_context('spawn'),
    'threading': ThreadingContext
}
Example #11
import os
import pdb
import multiprocessing
import pickle
import itertools
import datetime
import operator
import time
import math
import copy

# John hack

import multiprocess
ctx = multiprocess.get_context("spawn")
# multiprocess.set_start_method('spawn',force=True)
# from multiprocess import Pool as mpPool

import matplotlib as M
import matplotlib.pyplot as plt
import numpy as N
import pandas
import sklearn.preprocessing
import sklearn.decomposition
# from scipy.spatial import KDTree
import scipy.spatial.distance as ssd
from mpl_toolkits.basemap import Basemap

import evac.utils as utils
from evac.stats.detscores import DetScores
Example #12
        # if the file isn't found, run the script first without the --use_goodreads argument
        G = pickle.load(open("pickled_graphs/small_graph.p", "rb"))
    else:
        G = nx.DiGraph()

    books_memory = []

    for id_ in metadata_calibre.index:
        with open(f'books_raw/{id_}.txt', 'r') as book_text:
            books_memory.append(book_text.readlines())

    iterator = list(
        zip(metadata_calibre.index, metadata_calibre['aliases'],
            metadata_calibre['clean_title'], books_memory))

    ctx = mp.get_context('spawn')

    num_workers = 20 if not args.use_citation_model else 4

    p = ctx.Pool(num_workers)

    f = book_processer

    jobs = []
    send_models = True

    for i in range(len(iterator)):

        if args.use_citation_model:
            job = p.apply_async(f, [iterator[i], model, tokenizer])
        else:
Example #13
    def __init__(self, maxsize=0, ctx=mp.get_context("spawn")):

        super().__init__(maxsize, ctx=ctx)
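
The base class and the "clearable" behaviour of ClearableMPQueue are not shown in these examples. A hypothetical self-contained sketch, assuming it subclasses multiprocessing.queues.Queue and that clearing means draining all pending items:

import queue
import multiprocessing as mp
from multiprocessing.queues import Queue as MPQueue

class ClearableMPQueue(MPQueue):
    """Hypothetical sketch: a multiprocessing queue that can be drained on demand."""

    def __init__(self, maxsize=0, ctx=None):
        super().__init__(maxsize, ctx=ctx or mp.get_context("spawn"))

    def clear(self):
        # Discard everything currently queued without blocking.
        while True:
            try:
                self.get_nowait()
            except queue.Empty:
                break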
Example #14
 def __init__(self, *args, **kwargs):
     ctx = mp.get_context()
     super().__init__(*args, **kwargs, ctx=ctx)
     self.buff = {}
Example #15
def main(settings, rescue_running=[]):
    """
    Perform the primary loop of building, submitting, monitoring, and analyzing jobs.

    This function works via a loop of calls to thread.process and thread.interpret for each thread that hasn't
    terminated, until either the global termination criterion is met or all the individual threads have completed.

    Parameters
    ----------
    settings : argparse.Namespace
        Settings namespace object
    rescue_running : list
        List of threads passed in from handle_loop_exception, containing running threads. If given, setup is skipped and
        the function proceeds directly to the main loop.

    Returns
    -------
    exit_message : str
        A message indicating the status of ATESA at the end of main

    """

    if not rescue_running:
        # Implement resample
        if settings.job_type in ['aimless_shooting', 'committor_analysis'
                                 ] and settings.resample:
            # Store settings object in the working directory for compatibility with analysis/utility scripts
            if not settings.dont_dump:
                temp_settings = copy.deepcopy(
                    settings
                )  # initialize temporary copy of settings to modify
                temp_settings.__dict__.pop(
                    'env')  # env attribute is not picklable
                pickle.dump(temp_settings,
                            open(settings.working_directory + '/settings.pkl',
                                 'wb'),
                            protocol=2)
            # Run resampling
            if settings.job_type == 'aimless_shooting':
                utilities.resample(settings,
                                   partial=False,
                                   full_cvs=settings.full_cvs)
                if settings.information_error_checking:  # update info_err.out if called for by settings
                    information_error.main()
            elif settings.job_type == 'committor_analysis':
                resample_committor_analysis.resample_committor_analysis(
                    settings)
            return 'Resampling complete'

        # Make working directory if it does not exist, handling overwrite and restart as needed
        if os.path.exists(settings.working_directory):
            if settings.overwrite and not settings.restart:
                if os.path.exists(
                        settings.working_directory +
                        '/cvs.txt'):  # a kludge to avoid removing cvs.txt
                    if os.path.exists('ATESA_TEMP_CVS.txt'):
                        raise RuntimeError(
                            'tried to create temporary file ATESA_TEMP_CVS.txt in directory: '
                            + os.getcwd() +
                            ', but it already exists. Please move, delete, or rename it.'
                        )
                    shutil.move(settings.working_directory + '/cvs.txt',
                                'ATESA_TEMP_CVS.txt')
                shutil.rmtree(settings.working_directory)
                os.mkdir(settings.working_directory)
                if os.path.exists('ATESA_TEMP_CVS.txt'
                                  ):  # continuation of aforementioned kludge
                    shutil.move('ATESA_TEMP_CVS.txt',
                                settings.working_directory + '/cvs.txt')
            elif not settings.restart and glob.glob(
                    settings.working_directory +
                    '/*') == [settings.working_directory + '/cvs.txt']:
                # Occurs when restart = False, overwrite = False, and auto_cvs is used
                pass
            elif not settings.restart:
                raise RuntimeError(
                    'Working directory ' + settings.working_directory +
                    ' already exists, but overwrite '
                    '= False and restart = False. Either change one of these two settings or choose a '
                    'different working directory.')
        else:
            if not settings.restart:
                os.mkdir(settings.working_directory)
            else:
                raise RuntimeError('Working directory ' +
                                   settings.working_directory +
                                   ' does not yet exist, but '
                                   'restart = True.')

        # Store settings object in the working directory for compatibility with analysis/utility scripts
        if os.path.exists(
                settings.working_directory +
                '/settings.pkl'):  # for checking for need for resample later
            previous_settings = pickle.load(
                open(settings.working_directory + '/settings.pkl', 'rb'))
            settings.previous_cvs = previous_settings.cvs
            try:
                settings.previous_information_error_max_dims = previous_settings.information_error_max_dims
            except AttributeError:
                pass
            try:
                settings.previous_information_error_lmax_string = previous_settings.information_error_lmax_string
            except AttributeError:
                pass
        if not settings.dont_dump:
            temp_settings = copy.deepcopy(
                settings)  # initialize temporary copy of settings to modify
            temp_settings.__dict__.pop(
                'env'
            )  # env attribute is not picklable (update: maybe no longer true, but doesn't matter)
            pickle.dump(temp_settings,
                        open(settings.working_directory + '/settings.pkl',
                             'wb'),
                        protocol=2)

        # Build or load threads
        allthreads = init_threads(settings)

        # Move runtime to working directory
        os.chdir(settings.working_directory)

        running = allthreads.copy()  # to be pruned later by thread.process()
        attempted_rescue = False  # to keep track of general error handling below
    else:
        allthreads = pickle.load(
            open(settings.working_directory + '/restart.pkl', 'rb'))
        running = rescue_running
        attempted_rescue = True

    # Initialize threads with first process step
    try:
        if not rescue_running:  # if rescue_running, this step has already finished and we just want the while loop
            for thread in allthreads:
                running = thread.process(running, settings)
    except Exception as e:
        if settings.restart:
            print(
                'The following error occurred while attempting to initialize threads from restart.pkl. It may be '
                'corrupted.')
            #'If you haven\'t already done so, consider running verify_threads.py to remove corrupted threads from this file.'
        raise e

    try:
        if settings.job_type == 'aimless_shooting' and len(
                os.sched_getaffinity(0)) > 1:
            # Initialize Manager for shared data across processes; this is necessary because multiprocessing is being
            # retrofitted to code designed for serial processing, but it works!
            manager = Manager()

            # Setup Managed allthreads list
            managed_allthreads = []
            for thread in allthreads:
                thread_dict = thread.__dict__
                thread_history_dict = thread.history.__dict__
                managed_thread = Thread()
                managed_thread.history = manager.Namespace()
                managed_thread.__dict__.update(thread_dict)
                managed_thread.history.__dict__.update(thread_history_dict)
                managed_allthreads.append(managed_thread)
            allthreads = manager.list(managed_allthreads)

            # Setup Managed settings Namespace
            settings_dict = settings.__dict__
            managed_settings = manager.Namespace()
            # Need to explicitly update every key because of how the Managed Namespace works.
            # Calling exec is the best way to do this I could find. Updating managed_settings.__dict__ doesn't work.
            for key in settings_dict.keys():
                exec('managed_settings.' + key + ' = settings_dict[key]')

            # Distribute processes among available core Pool
            with get_context("spawn").Pool(len(os.sched_getaffinity(0))) as p:
                p.starmap(
                    main_loop,
                    zip(itertools.repeat(managed_settings),
                        itertools.repeat(allthreads),
                        [[thread] for thread in allthreads]))
        else:
            main_loop(settings, allthreads, running)
    except AttributeError:  # os.sched_getaffinity raises AttributeError on non-UNIX systems.
        main_loop(settings, allthreads, running)

    ## Deprecated thread pool
    # pool = ThreadPool(len(allthreads))
    # func = partial(main_loop, settings)
    # results = pool.map(func, [[thread] for thread in allthreads])

    jobtype = factory.jobtype_factory(settings.job_type)
    jobtype.cleanup(settings)

    return 'ATESA run exiting normally'
Example #16
import threading

# If multiprocess is installed we want to use that as it has more capabilities than regular multiprocessing (e.g.,
# pickling lambdas and functions located in __main__)
try:
    import multiprocess as mp
except ImportError:
    import multiprocessing as mp


# Threading context so we can use threading as backend as well
class ThreadingContext:

    Event = threading.Event
    Lock = threading.Lock
    Thread = threading.Thread

    # threading doesn't have Array and JoinableQueue, so we take it from multiprocessing. Both are thread-safe
    Array = mp.Array
    JoinableQueue = mp.JoinableQueue


MP_CONTEXTS = {'fork': mp.get_context('fork'),
               'forkserver': mp.get_context('forkserver'),
               'spawn': mp.get_context('spawn'),
               'threading': ThreadingContext}
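
A small usage sketch (the start_worker helper and the worker target are hypothetical): the point of the dictionary is that every backend, including 'threading', is selected by name and exposes the same attributes.

def start_worker(backend, target, *args):
    # Pick the backend by name; ThreadingContext deliberately mirrors the context API.
    ctx = MP_CONTEXTS[backend]
    q = ctx.JoinableQueue()
    worker_cls = ctx.Thread if backend == 'threading' else ctx.Process
    worker = worker_cls(target=target, args=(q,) + args, daemon=True)
    worker.start()
    return worker, q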
Example #17
def _execute(
    workflow: Workflow,
    workers: int,
    resources: Dict[str, int],
    context: Optional[BaseContext] = None,
):  # noqa

    if context is None:
        context = mp.get_context()

    buffer = 10

    manager = context.Manager()

    q_set = QueueSet(manager)

    tasks = {task.task_id: task for task in workflow.to_task_list()}

    resource_manager = ResourceManager(resources)

    ref_manager = ReferenceManager(tasks=tasks, storage=workflow.storage)

    running: List[str] = []

    try:

        # started workers
        ws: List[Worker] = []
        for _ in range(workers):
            w = Worker(q_set, workflow.storage, context)
            w.run()
            ws.append(w)

        try:

            while len(tasks) > 0 or len(running) > 0:

                for worker in ws:
                    if not worker.is_alive():
                        raise Termination("Detected an unexpectedly dead worker")

                try:
                    while True:
                        msg: Message = q_set.q_out.get_nowait()

                        assert msg.kind in (Kind.DONE, Kind.GENERATED)

                        if msg.kind == Kind.DONE:
                            running.remove(msg.content["task"].task_id)
                            resource_manager.remove(msg.content["task"])
                            ref_manager.remove(msg.content["task"])
                        elif msg.kind == Kind.GENERATED:
                            running.remove(msg.content["task"].task_id)
                            resource_manager.remove(msg.content["task"])

                            new_tasks: Dict[str, Task] = msg.content["tasks"]
                            for task_id, task in new_tasks.items():
                                if task_id not in running:
                                    tasks[task_id] = task
                                    ref_manager.add(task)

                            ref_manager.remove(msg.content["task"])

                except queue.Empty:
                    pass

                try:
                    msg = q_set.q_err.get_nowait()
                    logger.error("raise[task_id={}]".format(
                        msg.content["task_id"]))
                    print(msg.content["trace"])
                    raise Termination("Raised error on task_id={}".format(
                        msg.content["task_id"]) + "trace:\n" +
                                      msg.content["trace"])
                except queue.Empty:
                    pass

                if q_set.q_in.qsize() >= buffer:
                    time.sleep(0.2)
                    continue

                next_tasks = OrderedDict()

                for task in tasks.values():
                    if task.task_id in running:
                        continue

                    if is_completed(task, workflow.storage):
                        continue

                    if q_set.q_in.qsize() >= buffer:
                        next_tasks[task.task_id] = task
                        continue

                    inputs = flatten(task.input())

                    dependent_tasks_to_execute = OrderedDict()

                    for inp in inputs:

                        if exists_output(inp, workflow.storage):
                            continue

                        dependent_tasks_to_execute[
                            inp.src_task.task_id] = inp.src_task

                    # Case if there is any in-complete task
                    if len(dependent_tasks_to_execute) > 0:
                        for key, value in dependent_tasks_to_execute.items():
                            next_tasks[key] = value
                        next_tasks[task.task_id] = task
                        continue

                    if not resource_manager.is_runnable(task):
                        continue

                    q_set.q_in.put(
                        Message(kind=Kind.RUN, content={"task": task}))

                    running.append(task.task_id)

                    resource_manager.add(task)

                tasks = next_tasks

                time.sleep(0.2)

        finally:
            time.sleep(1)
            shutdown_all(ws)
    finally:
        manager.shutdown()
Example #18
def ranch(nprocs: int,
          fn: Callable,
          *args,
          caller_rank: int = 0,
          gather: bool = True,
          ctx: AbstractContextManager = None,
          need: str = "",
          imports="",
          **kwargs):
    """ Execute `fn(\*args, \*\*kwargs)` distributedly in `nprocs` processes.  User can
    serialize over objects and functions, spell out import statements, manage execution
    context, gather results, and the parent process can participate as one of the workers.

    If `caller_rank` is `0 <= caller_rank < nprocs`, only `nprocs - 1` processes will be forked, and the caller process will be a worker to run its share of `fn(..)`.

    If `caller_rank` is ``None``, `nprocs` processes will be forked.

    Inside each worker process, its relative rank among all workers is set up in `os.environ['LOCAL_RANK']`, and the total
    number of workers is set up in `os.environ['LOCAL_WORLD_SIZE']`, both as strings.

    Then import statements in `imports`, followed by any objects/functions in `need`, are brought
    into the python global namespace.

    Then, context manager `ctx` is applied around the call `fn(\*args, \*\*kwargs)`.

    Return value of each worker can be gathered in a list (indexed by the process's rank)
    and returned to the caller of `ranch()`.

    Args:
        nprocs: Number of processes to fork.  Visible as a string in `os.environ['LOCAL_WORLD_SIZE']`
            in all worker processes.
        fn: Function to execute on the worker pool
        \*args: Positional arguments by values to `fn(\*args....)`
        \*\*kwargs: Named parameters to `fn(x=..., y=....)`
        caller_rank: Rank of the parent process.  ``0 <= caller_rank < nprocs`` to join, ``None`` to opt out. Default to ``0``.

            In distributed data parallel, 0 means the leading process.
        gather: if ``True``, `ranch` will return a list of return values from each worker, indexed by their ranks.
            If ``False``, and if 'caller_rank' is not None (meaning parent process is a worker),
            `ranch()` will return whatever the parent process' `fn(...)` returns.
        ctx: User defined context manager to be used in a 'with'-clause around the 'fn(...)' call in worker processes.
            Subclassed from AbstractContextManager, ctx needs to define '__enter__()' and '__exit__()' methods.
        need: Space-separated names of objects/functions to be serialized over to the subprocesses.
        imports: A multiline string of `import` statements to execute in the subprocesses
            before `fn()` execution.  Supported formats:

            * `import x, y, z as zoo`

            * `from A import x`

            * `from A import z as zoo`

            * `from A import x, y, z as zoo`

            * Not supported: `from A import (x, y)`

    Returns:
        ``None``, or list of results from worker processes, indexed by their `LOCAL_RANK`: ``[res_0, res_1, .... res_{nprocs-1}]``
    """

    assert nprocs > 0, ValueError(
        "nprocs: # of processes to launch must be > 0")

    children_ranks = list(range(nprocs))
    if caller_rank is not None:
        assert 0 <= caller_rank < nprocs, ValueError(
            f"Invalid caller_rank {caller_rank}, must satisfy 0 <= caller_rank < {nprocs}"
        )
        children_ranks.pop(caller_rank)
    multiproc_ctx, procs = mp.get_context("spawn"), []
    result_list = multiproc_ctx.Manager().list([None] *
                                               nprocs) if gather else None
    try:
        # pass globals in this process to subprocess via fn's wrapper, 'target_fn'
        env = {k: sys.modules['__main__'].__dict__[k] for k in need.split()}
        for rank in children_ranks:
            target_fn = _contextualize(rank,
                                       nprocs,
                                       fn,
                                       cm=ctx,
                                       l=result_list,
                                       env=env,
                                       imports=imports)
            p = multiproc_ctx.Process(target=target_fn,
                                      args=args,
                                      kwargs=kwargs)
            procs.append(p)
            p.start()
        p_res = (_contextualize(
            caller_rank,
            nprocs,
            fn,
            cm=ctx,
            l=result_list,
            env=env,
            imports=imports))(*args, **
                              kwargs) if caller_rank is not None else None
        for p in procs:
            p.join()
        return result_list if gather else p_res
    finally:
        for p in procs:
            p.terminate(), p.join()
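
A hypothetical usage sketch for ranch above, assuming work is defined at module level and the call is guarded by if __name__ == '__main__' (needed for the spawn start method):

import os

def work(tag):
    # Each worker reads its rank from the environment variables set up by ranch().
    return f"{tag}: rank {os.environ['LOCAL_RANK']} of {os.environ['LOCAL_WORLD_SIZE']}"

if __name__ == '__main__':
    results = ranch(4, work, "hello", caller_rank=0, gather=True)
    print(results)  # list of 4 return values, indexed by each worker's LOCAL_RANK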
Example #19
def main():
    # argument parsing via argparse

    # the ascii help image
    help_image = "█▀▀█ ░▀░ █░█ █░░█\n" "█░░█ ▀█▀ ▄▀▄ █▄▄█\n" "█▀▀▀ ▀▀▀ ▀░▀ ▄▄▄█\n"

    help_text = 'pixy: unbiased estimates of pi, dxy, and fst from VCFs with invariant sites'
    version = '1.2.5.beta1'
    citation = 'Korunes, KL and K Samuk. pixy: Unbiased estimation of nucleotide diversity and divergence in the presence of missing data. Mol Ecol Resour. 2021 Jan 16. doi: 10.1111/1755-0998.13326.'

    # initialize arguments
    parser = argparse.ArgumentParser(
        description=help_image + help_text + '\n' + version,
        formatter_class=argparse.RawTextHelpFormatter)

    parser._action_groups.pop()
    required = parser.add_argument_group('required arguments')
    additional = parser.add_argument_group('in addition, one of')
    optional = parser.add_argument_group('optional arguments')

    required.add_argument(
        '--stats',
        nargs='+',
        choices=['pi', 'dxy', 'fst'],
        help=
        'List of statistics to calculate from the VCF, separated by spaces.\ne.g. \"--stats pi fst\" will result in pi and fst calculations.',
        required=True)
    required.add_argument(
        '--vcf',
        type=str,
        nargs='?',
        help='Path to the input VCF (bgzipped and tabix indexed).',
        required=True)
    required.add_argument(
        '--populations',
        type=str,
        nargs='?',
        help=
        'Path to a headerless tab separated populations file with columns [SampleID Population].',
        required=True)

    additional.add_argument(
        '--window_size',
        type=int,
        nargs='?',
        help=
        'Window size in base pairs over which to calculate stats.\nAutomatically determines window coordinates/bounds (see additional options below).',
        required=False)
    additional.add_argument(
        '--bed_file',
        type=str,
        nargs='?',
        help=
        'Path to a headerless .BED file with columns [chrom chromStart chromEnd].\nManually defines window bounds, which can be heterogenous in size.',
        required=False)

    optional.add_argument(
        '--n_cores',
        type=int,
        nargs='?',
        default=1,
        help='Number of CPUs to utilize for parallel processing (default=1).',
        required=False)
    optional.add_argument(
        '--output_folder',
        type=str,
        nargs='?',
        default='',
        help=
        'Folder where output will be written, e.g. path/to/output_folder.\nDefaults to current working directory.',
        required=False)
    optional.add_argument(
        '--output_prefix',
        type=str,
        nargs='?',
        default='pixy',
        help=
        'Optional prefix for output file(s), with no slashes.\ne.g. \"--output_prefix output\" will result in [output folder]/output_pi.txt. \nDefaults to \'pixy\'.',
        required=False)
    optional.add_argument(
        '--chromosomes',
        type=str,
        nargs='?',
        default='all',
        help=
        'A single-quoted, comma separated list of chromosomes where stats will be calculated. \ne.g. --chromosomes \'X,1,2\' will restrict stats to chromosomes X, 1, and 2.\nDefaults to all chromosomes in the VCF.',
        required=False)
    optional.add_argument(
        '--interval_start',
        type=str,
        nargs='?',
        help=
        'The start of an interval over which to calculate stats.\nOnly valid when calculating over a single chromosome.\nDefaults to 1.',
        required=False)
    optional.add_argument(
        '--interval_end',
        type=str,
        nargs='?',
        help=
        'The end of the interval over which to calculate stats.\nOnly valid when calculating over a single chromosome.\nDefaults to the last position for a chromosome.',
        required=False)
    optional.add_argument(
        '--sites_file',
        type=str,
        nargs='?',
        help=
        'Path to a headerless tab separated file with columns [CHROM POS].\nThis defines sites where summary stats should be calculated.\nCan be combined with the --bed_file and --window_size options.',
        required=False)
    optional.add_argument(
        '--chunk_size',
        type=int,
        nargs='?',
        default=100000,
        help=
        'Approximate number of sites to read from VCF at any given time (default=100000).\nLarger numbers reduce I/O operations at the cost of memory.',
        required=False)

    optional.add_argument(
        '--fst_type',
        choices=['wc', 'hudson'],
        default='wc',
        help=
        'FST estimator to use, one of either: \n\'wc\' (Weir and Cockerham 1984) or\n\'hudson\' (Hudson 1992, Bhatia et al. 2013) \nDefaults to \'wc\'.',
        required=False)
    optional.add_argument(
        '--bypass_invariant_check',
        choices=['yes', 'no'],
        default='no',
        help=
        'Allow computation of stats without invariant sites (default=no).\nWill result in wildly incorrect estimates most of the time.\nUse with extreme caution!',
        required=False)
    optional.add_argument('--version',
                          action='version',
                          version=help_image + 'version ' + version,
                          help='Print the version of pixy in use.')
    optional.add_argument('--citation',
                          action='version',
                          version=citation,
                          help='Print the citation for pixy.')
    optional.add_argument('--silent',
                          action='store_true',
                          help='Suppress all console output.')
    optional.add_argument('--debug',
                          action='store_true',
                          help=argparse.SUPPRESS)
    optional.add_argument('--keep_temp_file',
                          action='store_true',
                          help=argparse.SUPPRESS)

    # catch arguments from the command line
    # automatically uncommented when a release is built
    args = parser.parse_args()

    # if not running in debug mode, suppress traceback
    if not args.debug:
        sys.tracebacklimit = 0

    # if running in silent mode, suppress output
    if args.silent:
        sys.stdout = open(os.devnull, "w")

    # validate arguments with the check_and_validate_args function
    # returns parsed population, chromosome, and sample info
    print("[pixy] pixy " + version)
    print("[pixy] See documentation at https://pixy.readthedocs.io/en/latest/")
    popnames, popindices, chrom_list, IDs, temp_file, output_folder, output_prefix, bed_df, sites_df = pixy.core.check_and_validate_args(
        args)

    print("\n[pixy] Preparing for calculation of summary statistics: " +
          ', '.join(map(str, args.stats)))

    if 'fst' in args.stats:
        if args.fst_type == 'wc':
            fst_cite = 'Weir and Cockerham (1984)'
        elif args.fst_type == 'hudson':
            fst_cite = 'Hudson (1992)'
        print("[pixy] Using " + fst_cite + "\'s estimator of FST.")

    print("[pixy] Data set contains " + str(len(popnames)) +
          " population(s), " + str(len(chrom_list)) + " chromosome(s), and " +
          str(len(IDs)) + " sample(s)")

    if args.window_size is not None:
        print("[pixy] Window size: " + str(args.window_size) + " bp")

    if args.bed_file is not None:
        print("[pixy] Windows sourced from: " + args.bed_file)

    if args.sites_file is not None:
        print("[pixy] Calculations restricted to sites in: " + args.sites_file)

    print('')

    # time the calculations
    start_time = time.time()
    print("[pixy] Started calculations at " +
          time.strftime("%H:%M:%S on %Y-%m-%d", time.localtime(start_time)))
    print("[pixy] Using " + str(args.n_cores) + " out of " +
          str(mp.cpu_count()) + " available CPU cores\n")
    # if in mc mode, set up multiprocessing
    if (args.n_cores > 1):

        # use the fork context on Linux, and spawn otherwise (e.g. macOS)
        if sys.platform == "linux":
            ctx = mp.get_context("fork")
        else:
            ctx = mp.get_context("spawn")

        # set up the multiprocessing manager, queue, and process pool
        manager = ctx.Manager()
        q = manager.Queue()
        pool = ctx.Pool(int(args.n_cores))

        # a listener function for writing a temp file
        # used to write output in multicore mode
        def listener(q, temp_file):

            with open(temp_file, 'a') as f:
                while 1:
                    m = q.get()
                    if m == 'kill':
                        break
                    f.write(str(m) + '\n')
                    f.flush()

        # launch the watcher function for collecting output
        watcher = pool.apply_async(listener, args=(
            q,
            temp_file,
        ))

    # begin processing each chromosome

    for chromosome in chrom_list:

        print("[pixy] Processing chromosome/contig " + chromosome + "...")

        # if not using a bed file, build windows manually
        if args.bed_file is None:

            # if an interval is specified, assign it
            if args.interval_start is not None:
                interval_start = int(args.interval_start)
                interval_end = int(args.interval_end)

            # otherwise, get the interval from the VCF's POS column
            else:
                if args.sites_file is None:
                    chrom_max = subprocess.check_output(
                        "tabix " + args.vcf + " " + chromosome +
                        " | cut -f 2 | tail -n 1",
                        shell=True).decode("utf-8").split()
                    interval_start = 1
                    interval_end = int(chrom_max[0])
                else:
                    sites_df = pandas.read_csv(args.sites_file,
                                               sep='\t',
                                               usecols=[0, 1],
                                               names=['CHROM', 'POS'])
                    sites_pre_list = sites_df[sites_df['CHROM'] == chromosome]
                    sites_pre_list = sorted(sites_pre_list['POS'].tolist())
                    interval_start = min(sites_pre_list)
                    interval_end = max(sites_pre_list)

            # final check if intervals are valid
            try:
                if (interval_start > interval_end):
                    raise ValueError()
            except ValueError as e:
                raise Exception(
                    "[pixy] ERROR: The specified interval start (" +
                    str(interval_start) + ") exceeds the interval end (" +
                    str(interval_end) + ")") from e

            targ_region = chromosome + ":" + str(interval_start) + "-" + str(
                interval_end)

            print("[pixy] Calculating statistics for region " + targ_region +
                  "...")

            # Determine list of windows over which to compute stats

            # window size
            window_size = args.window_size

            # in the case where window size = 1 AND there is a sites file, use the sites file as the 'windows'
            if window_size == 1 and args.sites_file is not None:
                window_list = [
                    list(a) for a in zip(sites_pre_list, sites_pre_list)
                ]
            else:
                # if the interval is smaller than one window, make a list of length 1
                if (interval_end - interval_start) <= window_size:
                    window_pos_1_list = [interval_start]
                    window_pos_2_list = [interval_start + window_size - 1]
                else:
                    # if the interval_end is not a perfect multiple of the window size
                    # bump the interval_end up to the nearest multiple of window size
                    if not (interval_end % window_size == 0):
                        interval_end = interval_end + (
                            window_size - (interval_end % window_size))

                    # create the start and stops for each window
                    window_pos_1_list = [
                        *range(interval_start, int(interval_end), window_size)
                    ]
                    window_pos_2_list = [
                        *range(interval_start + (window_size - 1),
                               int(interval_end) + window_size, window_size)
                    ]

                window_list = [
                    list(a) for a in zip(window_pos_1_list, window_pos_2_list)
                ]

            # Set aggregate to true if 1) the window size is larger than the chunk size OR 2) the window size wasn't specified, but the chrom is longer than the cutoff
            if (window_size > args.chunk_size) or (
                (args.window_size is None) and
                ((interval_end - interval_start) > args.chunk_size)):
                aggregate = True
            else:
                aggregate = False

        # if using a bed file, subset the bed file for the current chromosome
        else:
            aggregate = False
            bed_df_chrom = bed_df[bed_df['chrom'] == chromosome]
            window_list = [
                list(a) for a in zip(bed_df_chrom['chromStart'],
                                     bed_df_chrom['chromEnd'])
            ]

        if (len(window_list) == 0):
            raise Exception(
                "[pixy] ERROR: Window creation failed. Ensure that the POS column in the VCF is valid or change --window_size."
            )

        # if aggregating, break down large windows into smaller windows
        if aggregate:
            window_list = pixy.core.assign_subwindows_to_windows(
                window_list, args.chunk_size)

        # using chunk_size, assign  windows to chunks
        window_list = pixy.core.assign_windows_to_chunks(
            window_list, args.chunk_size)

        # if using a sites file, assign sites to chunks, as with windows above
        if args.sites_file is not None:
            sites_df = pandas.read_csv(args.sites_file,
                                       sep='\t',
                                       usecols=[0, 1],
                                       names=['CHROM', 'POS'])
            sites_pre_list = sites_df[sites_df['CHROM'] == chromosome]
            sites_pre_list = sites_pre_list['POS'].tolist()
            sites_list = pixy.core.assign_sites_to_chunks(
                sites_pre_list, args.chunk_size)
        else:
            sites_df = None

        # obtain the list of chunks from the window list
        chunk_list = [i[2] for i in window_list]
        chunk_list = list(set(chunk_list))

        # if running in multicore mode, send the summary stats function to the jobs pool
        if (args.n_cores > 1):

            # the list of jobs to be launched
            jobs = []

            for chunk in chunk_list:

                # create a subset of the window list specific to this chunk
                window_list_chunk = [x for x in window_list if x[2] == chunk]

                # and for the site list (if it exists)
                if sites_df is not None:
                    sites_list_chunk = [x for x in sites_list if x[1] == chunk]
                    sites_list_chunk = [x[0] for x in sites_list_chunk]
                else:
                    sites_list_chunk = None

                # determine the bounds of the chunk
                chunk_pos_1 = min(window_list_chunk, key=lambda x: x[1])[0]
                chunk_pos_2 = max(window_list_chunk, key=lambda x: x[1])[1]

                # launch a summary stats job for this chunk
                job = pool.apply_async(
                    pixy.core.compute_summary_stats,
                    args=(args, popnames, popindices, temp_file, chromosome,
                          chunk_pos_1, chunk_pos_2, window_list_chunk, q,
                          sites_list_chunk, aggregate, args.window_size))
                jobs.append(job)

            # launch all the jobs onto the pool
            for job in jobs:
                job.get()

        # if running in single core mode, loop over the function manually
        elif (args.n_cores == 1):

            for chunk in chunk_list:

                # create a subset of the window list specific to this chunk
                window_list_chunk = [x for x in window_list if x[2] == chunk]

                # and for the site list (if it exists)
                if sites_df is not None:
                    sites_list_chunk = [x for x in sites_list if x[1] == chunk]
                    sites_list_chunk = [x[0] for x in sites_list_chunk]
                else:
                    sites_list_chunk = None

                # determine the bounds of the chunk
                chunk_pos_1 = min(window_list_chunk, key=lambda x: x[1])[0]
                chunk_pos_2 = max(window_list_chunk, key=lambda x: x[1])[1]

                # don't use the queue (q) when running in single core mode
                q = "NULL"

                # compute summary stats for all windows in the chunk window list
                pixy.core.compute_summary_stats(args, popnames, popindices,
                                                temp_file, chromosome,
                                                chunk_pos_1, chunk_pos_2,
                                                window_list_chunk, q,
                                                sites_list_chunk, aggregate,
                                                args.window_size)

    # clean up any remaining jobs and stop the listener
    if (args.n_cores > 1):
        q.put('kill')
        pool.close()
        pool.join()

    # split and aggregate temp file to individual files

    # check if there is any output to process
    # halt execution if not
    try:
        outpanel = pandas.read_csv(temp_file, sep='\t', header=None)
    except pandas.errors.EmptyDataError:
        raise Exception(
            '[pixy] ERROR: pixy failed to write any output. Confirm that your bed/sites files and intervals refer to existing chromosomes and positions in the VCF.'
        )

    # check if particular stats failed to generate output
    successful_stats = np.unique(outpanel[0])

    # if not all requested stats were generated, produce a warning
    # and then remove the failed stats from the args list
    if set(args.stats) != set(successful_stats):
        missing_stats = list(set(args.stats) - set(successful_stats))
        print(
            '\n[pixy] WARNING: pixy failed to find any valid genotype data to calculate the following summary statistics: '
            + ', '.join(missing_stats) +
            ". No output file will be created for these statistics.")
        args.stats = set(successful_stats)

    outpanel[3] = outpanel[3].astype(str)  # force chromosome IDs to string
    outgrouped = outpanel.groupby([0, 3])  #groupby statistic, chromosome

    # enforce chromosome IDs as strings
    chrom_list = list(map(str, chrom_list))

    if 'pi' in args.stats:
        stat = 'pi'
        pi_file = str(output_prefix) + "_pi.txt"

        if os.path.exists(pi_file):
            os.remove(pi_file)

        outfile = open(pi_file, 'a')
        outfile.write("pop" + "\t" + "chromosome" + "\t" + "window_pos_1" +
                      "\t" + "window_pos_2" + "\t" + "avg_pi" + "\t" +
                      "no_sites" + "\t" + "count_diffs" + "\t" +
                      "count_comparisons" + "\t" + "count_missing" + "\n")

        if aggregate:  #put winsizes back together for each population to make final_window_size

            for chromosome in chrom_list:
                outpi = outgrouped.get_group(("pi", chromosome)).reset_index(
                    drop=True)  #get this statistic, this chrom only
                outpi.drop([0, 2], axis=1, inplace=True
                           )  #get rid of "pi" and placeholder (NA) columns
                outsorted = pixy.core.aggregate_output(outpi, stat, chromosome,
                                                       window_size,
                                                       args.fst_type)
                outsorted.to_csv(outfile,
                                 sep="\t",
                                 mode='a',
                                 header=False,
                                 index=False,
                                 na_rep='NA')  #write

        else:
            for chromosome in chrom_list:
                outpi = outgrouped.get_group(("pi", chromosome)).reset_index(
                    drop=True)  #get this statistic, this chrom only
                outpi.drop([0, 2], axis=1, inplace=True
                           )  #get rid of "pi" and placeholder (NA) columns
                outsorted = outpi.sort_values([4])  #sort by position
                # make sure sites, comparisons, missing get written as integers
                cols = [7, 8, 9, 10]
                outsorted[cols] = outsorted[cols].astype('Int64')
                outsorted.to_csv(outfile,
                                 sep="\t",
                                 mode='a',
                                 header=False,
                                 index=False,
                                 na_rep='NA')  #write

        outfile.close()

    if 'dxy' in args.stats:
        stat = 'dxy'
        dxy_file = str(output_prefix) + "_dxy.txt"

        if os.path.exists(dxy_file):
            os.remove(dxy_file)

        outfile = open(dxy_file, 'a')
        outfile.write("pop1" + "\t" + "pop2" + "\t" + "chromosome" + "\t" +
                      "window_pos_1" + "\t" + "window_pos_2" + "\t" +
                      "avg_dxy" + "\t" + "no_sites" + "\t" + "count_diffs" +
                      "\t" + "count_comparisons" + "\t" + "count_missing" +
                      "\n")

        if aggregate:  # put window sizes back together for each population to make final_window_size

            for chromosome in chrom_list:
                # get this statistic, this chromosome only
                outdxy = outgrouped.get_group(("dxy", chromosome)).reset_index(drop=True)
                outdxy.drop([0], axis=1, inplace=True)  # get rid of the "dxy" label
                outsorted = pixy.core.aggregate_output(outdxy, stat,
                                                       chromosome, window_size,
                                                       args.fst_type)
                outsorted.to_csv(outfile,
                                 sep="\t",
                                 mode='a',
                                 header=False,
                                 index=False,
                                 na_rep='NA')  #write

        else:
            for chromosome in chrom_list:
                # get this statistic, this chromosome only
                outdxy = outgrouped.get_group(("dxy", chromosome)).reset_index(drop=True)
                outdxy.drop([0], axis=1, inplace=True)  # get rid of the "dxy" label
                outsorted = outdxy.sort_values([4])  # sort by window start position
                # make sure sites, comparisons, missing get written as integers
                cols = [7, 8, 9, 10]
                outsorted[cols] = outsorted[cols].astype('Int64')
                outsorted.to_csv(outfile,
                                 sep="\t",
                                 mode='a',
                                 header=False,
                                 index=False,
                                 na_rep='NA')  #write

        outfile.close()

    if 'fst' in args.stats:
        stat = 'fst'
        fst_file = str(output_prefix) + "_fst.txt"

        if os.path.exists(fst_file):
            os.remove(fst_file)

        outfile = open(fst_file, 'a')
        outfile.write("pop1" + "\t" + "pop2" + "\t" + "chromosome" + "\t" +
                      "window_pos_1" + "\t" + "window_pos_2" + "\t" + "avg_" +
                      args.fst_type + "_fst" + "\t" + "no_snps" + "\n")
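        # fst rows also carry per-window component columns; these are dropped
        # (iloc[:, :-3] below) before writing, so only the columns in this header
        # appear in the output file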

        # keep track of chromosomes with no fst data
        chroms_with_no_data = []

        if aggregate:  # put window sizes back together for each population to make final_window_size

            for chromosome in chrom_list:

                # logic to accommodate cases where pi/dxy have stats for a chromosome, but fst does not
                chromosome_has_data = True

                # if there are no valid fst estimates, set chromosome_has_data = False
                try:
                    # get this statistic, this chromosome only
                    outfst = outgrouped.get_group(
                        ("fst", chromosome)).reset_index(drop=True)
                except KeyError:
                    # get_group raises KeyError when there are no fst windows
                    # for this chromosome
                    chroms_with_no_data.append(chromosome)
                    chromosome_has_data = False

                if chromosome_has_data:
                    outfst.drop([0], axis=1, inplace=True)  # get rid of the "fst" label
                    outsorted = pixy.core.aggregate_output(
                        outfst, stat, chromosome, window_size, args.fst_type)
                    outsorted = outsorted.iloc[:, :-3]  # drop the component columns (for now)
                    outsorted.to_csv(outfile,
                                     sep="\t",
                                     mode='a',
                                     header=False,
                                     index=False,
                                     na_rep='NA')  #write

        else:
            for chromosome in chrom_list:

                # logic to accommodate cases where pi/dxy have stats for a chromosome, but fst does not
                chromosome_has_data = True

                # if there are no valid fst estimates, set chromosome_has_data = False
                try:
                    # get this statistic, this chromosome only
                    outfst = outgrouped.get_group(
                        ("fst", chromosome)).reset_index(drop=True)
                except KeyError:
                    chroms_with_no_data.append(chromosome)
                    chromosome_has_data = False

                if chromosome_has_data:
                    outfst.drop([0], axis=1, inplace=True)  # get rid of the "fst" label
                    outsorted = outfst.sort_values([4])  # sort by window start position
                    # make sure no_snps gets written as an integer
                    # (fst has no other count columns to cast, unlike pi/dxy)
                    cols = [7]
                    outsorted[cols] = outsorted[cols].astype('Int64')
                    outsorted = outsorted.iloc[:, :-3]  # drop the component columns (for now)
                    outsorted.to_csv(outfile,
                                     sep="\t",
                                     mode='a',
                                     header=False,
                                     index=False,
                                     na_rep='NA')

        outfile.close()

        if len(chroms_with_no_data) >= 1:
            print(
                '\n[pixy] NOTE: ' +
                'The following chromosomes/scaffolds did not have sufficient data to estimate FST: '
                + ', '.join(chroms_with_no_data))

    # remove the temp file(s)
    if args.keep_temp_file is not True:
        os.remove(temp_file)

    # confirm output was generated successfully
    outfolder_files = [
        f for f in os.listdir(output_folder)
        if os.path.isfile(os.path.join(output_folder, f))
    ]

    r = re.compile(".*_dxy.*|.*_pi.*|.*_fst.*")
    output_files = list(filter(r.match, outfolder_files))

    r = re.compile("pixy_tmpfile.*")
    leftover_tmp_files = list(filter(r.match, outfolder_files))
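    # (re.match anchors at the start of the filename, hence the leading '.*' in the
    # output-file pattern; temp files are matched by their 'pixy_tmpfile' prefix)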

    if len(output_files) == 0:
        print(
            '\n[pixy] WARNING: pixy failed to write any output files. Your VCF may not contain valid genotype data, or all of it may have been removed by filtering against the specified sites/BED file (if any).'
        )

    # print completion message
    end_time = time.time()
    print("\n[pixy] All calculations complete at " +
          time.strftime("%H:%M:%S on %Y-%m-%d", time.localtime(end_time)))
    total_time = (time.time() - start_time)
    print("[pixy] Time elapsed: " +
          time.strftime("%H:%M:%S", time.gmtime(total_time)))
    print("[pixy] Output files written to: " + output_folder)

    if len(leftover_tmp_files) > 0:
        print("\n[pixy] NOTE: There are pixy temp files in " + output_folder)
        print(
            "[pixy] If these are not actively being used (e.g. by another running pixy process), they can be safely deleted."
        )

    print(
        "\n[pixy] If you use pixy in your research, please cite the following paper:\n[pixy] "
        + citation)

    # restore stdout if it was redirected by args.silent
    if args.silent:
        sys.stdout = sys.__stdout__
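

# ---------------------------------------------------------------------------
# Hedged usage sketch, not part of pixy itself: one way to read the per-window
# pi output written above back into pandas. The file path is hypothetical; the
# column names and the 'NA' missing-value marker follow the header written in
# the pi block, and the count columns are restored to nullable integers.
# ---------------------------------------------------------------------------
import pandas as pd


def load_pixy_pi(path="pixy_output_pi.txt"):
    """Load a pixy *_pi.txt file produced by the code above."""
    df = pd.read_csv(path, sep="\t", na_values="NA")
    count_cols = ["no_sites", "count_diffs", "count_comparisons", "count_missing"]
    df[count_cols] = df[count_cols].astype("Int64")
    return df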