Example #1
def run_func(target_func, func_args, split_args, core_nums=2):
    """多线程运算

    Parameters
    -----------
    target_func: func
        待运行的函数。需要分配到不同进程的参数必须放在该函数参数列表的最前面,即:
        target_func(split_args, func_args)
    func_args: dict
        被传入到运行函数中
    split_args: two-dimensioned array N*K
        参数列表会平均分配到不同的进程中去。N代表参数个数,K代表每个参数下元素数量。
    core_nums: int
        创建进程的数量
    """
    s_args = np.array_split(split_args, core_nums, axis=1)
    p = Pool(core_nums)
    for i in range(core_nums):
        print("create process %s" % i)
        p.apply_async(target_func, args=tuple(s_args[i]), kwds=func_args,
                      callback=lambda x: print(x), error_callback=lambda x: print(x))
    p.close()
    p.join()
    print("calculation has finished!")
    
    def multi_align_tr(self, imstack, TrM, nsz, shx, shy, stfolder, sfn,
                       nCORES, fnames, ext):
        if sfn not in os.listdir(stfolder):
            os.makedirs(os.path.join(stfolder, sfn))
            print('directory created')
        pool = Pool(nCORES)
        print('applying transformations with', nCORES,
              'processes in parallel ')
        results = []
        for i in range(len(imstack)):
            # results.append(transform(imstack[i], TrM, nsz, shx, shy, stfolder, sfn, fnames, i, ext))
            results.append(
                pool.apply_async(transform, (
                    imstack[i],
                    TrM,
                    nsz,
                    shx,
                    shy,
                    stfolder,
                    sfn,
                    fnames,
                    i,
                    ext,
                )))
            self.loading.progress2['value'] += 1
            self.update()
        pool.close()
        pool.join()
        print('successfully transformed all the images in the stack')
        return results
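
Note that the method returns a list of AsyncResult handles rather than the transformed images themselves. A hedged sketch of how a caller inside the same class might unwrap them (all argument names are reused from the snippet):

        # Sketch: collect the actual return values once the pool has finished.
        results = self.multi_align_tr(imstack, TrM, nsz, shx, shy, stfolder, sfn,
                                      nCORES, fnames, ext)
        transformed = [r.get() for r in results]   # .get() re-raises any worker exception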
Example #3
    def extract_adj_noun_async(self):
        startTime = time.time()
        counters = []
        print "running on {} processors".format(WORKERS)
        pool = Pool(processes=WORKERS)  # ,initargs=(sent_locker, lock, sentence_counter))
        # adj_noun_dict = pool.map(self.__extract_patterns_from_file, self.data_wrapper.ngrams_files)
        results = [
            pool.apply_async(self.__extract_patterns_from_file, (file, ))
            for file in self.data_wrapper.ngrams_files
        ]

        # for res in results:
        #     dict_res = res.get()
        #     res_counter = Counter(dict_res)
        #     counters.append(res_counter)
        counters = [Counter(x.get()) for x in results]
        print "starting to sum counters from {} dictionaries".format(
            len(counters))
        self.adj_noun_to_count = sum(counters, Counter())
        print "done summing counters"
        # pool.close()
        # pool.join()
        total_time = time.time() - startTime
        print "extract_adj_noun_async running time: {}".format(total_time)
def plotting_demon(plotting_queue, multicore):
    print("Starting plotting daemon...", end=" ")
    pool = Pool(processes=multicore)
    print("Done!")
    while True:
        rec = plotting_queue.get()
        if rec == 0:
            break
        func, arguments = rec
        if isinstance(arguments, tuple):
            pool.apply_async(func, *arguments)
        else:
            pool.apply_async(func, arguments)

    pool.close()
    pool.join()
    del pool
    print("Plotting daemon terminated.")
    exit(0)
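
A sketch of a producer feeding this daemon; the multiprocess import, save_plot, and the queue contents are assumptions. Each queue item is a (func, arguments) pair, and a bare 0 acts as the shutdown sentinel. Because a tuple in the arguments slot is unpacked into apply_async's own (args, kwds, ...) parameters, plain positional arguments are easiest to pass as a list.

from multiprocess import Process, Queue

def save_plot(data, path):
    # Placeholder plotting job run inside the daemon's pool.
    print('would plot %d points to %s' % (len(data), path))

if __name__ == '__main__':
    q = Queue()
    daemon = Process(target=plotting_demon, args=(q, 2))
    daemon.start()
    q.put((save_plot, [[1, 2, 3], 'out.png']))   # list -> passed straight to apply_async as args
    q.put(0)                                     # sentinel: stop the daemon
    daemon.join()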
Example #5
    def inner(*args):
        pool = Pool(processes=1)
        res = pool.apply_async(f, args)
        try:
            v = res.get(timeout=sec)
        except Exception as inst:
            print(inst)
            v = None
        finally:
            pool.terminate()
            return v
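
Example #5 reads like the inner closure of a timeout decorator. Below is a sketch of the enclosing factory it was presumably taken from; the names timelimit, sec, and slow_add are assumptions.

import time
from multiprocess import Pool

def timelimit(sec):
    """Run the wrapped function in a single worker process, giving up after `sec` seconds."""
    def wrapper(f):
        def inner(*args):
            pool = Pool(processes=1)
            res = pool.apply_async(f, args)
            try:
                v = res.get(timeout=sec)
            except Exception as inst:
                print(inst)
                v = None
            finally:
                pool.terminate()
            return v
        return inner
    return wrapper

def slow_add(a, b):
    time.sleep(5)
    return a + b

if __name__ == '__main__':
    safe_add = timelimit(2)(slow_add)   # wrap a copy; slow_add itself stays importable
    print(safe_add(1, 2))               # times out after ~2 s and returns None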
Example #6
    def mp_pooler(self, nCORES, func, *args):
        pool = Pool(nCORES)
        print('computing with', nCORES, 'processes in parallel')
        results = []
        for i in range(len(args[0]) - 1):
            results.append(pool.apply_async(func, (args[0][i], args[0][i + 1], *args[1:], i)))
            self.loading.progress2['value'] += 1
            self.update()
        pool.close()
        pool.join()
        return results
Example #7
def download_image_thread(location_q, image_q, MAX_DL_THREADS=10):
    print("Running Download Image Thread.")

    max_processes = MAX_DL_THREADS
    print("Creating a thread pool of size {} for downloading images...".format(max_processes))
    pool = Pool(processes=max_processes)
    # Allow us to have n processes running, and n processes scheduled to run.
    # TODO: Manager is not necessary here, but is used to get around the fact
    # that thread-safe objects cannot be passed by reference; they must be
    # inherited. A more lightweight solution should be found.
    workers = Manager().Semaphore(max_processes*2)

    def async_download(location):
        image = download_image(location)
        image_q.put((location, image), True)
        workers.release()

    while True:
        location = location_q.get(True)
        workers.acquire()
        pool.apply_async(async_download, (location,))
Example #8
    def prime_calculate(self):
        break_points = []  # List that will have start and stopping points
        for i in range(cores):  # Creates start and stopping points based on length of range_finish
            break_points.append(
                {"start": int(math.ceil(((self.maximum_prime + 1) + 0.0) / cores * i)),
                 "stop": int(math.ceil(((self.maximum_prime + 1) + 0.0) / cores * (i + 1)))})

        p = Pool(cores)  # Number of processes to create.
        for i in break_points:  # Cycles though the breakpoints list created above.
            a = p.apply_async(self.prime_calculator, kwds=i, args=tuple(),
                              callback=self.update_num)  # This will start the separate processes.
        p.close()  # Prevents any more processes being started
        p.join()  # Waits for worker process to end
Example #9
def monteCarlo(agent, maxDepth=3, trials=12, frequency=10):
    PROCESSES = 4
    model = VelocityModel(
        regressionModel=joblib.load('models/gradient-m.model'),
        frequency=frequency)
    actions = np.array(agent.getActions())
    initialState, isTerminal = agent.getState(), 0

    jobs = [None] * len(actions) * trials
    while bool(isTerminal) is False:
        initialState = agent.getState()
        qs = {i: [] for i in actions}

        for index, a in enumerate(np.repeat(actions, trials)):
            virtualAgent, isTerminal = RLAgent(
                'virtual',
                alternativeModel=model,
                decisionFrequency=math.inf,
                maxDepth=maxDepth,
                initialState=initialState), False
            virtualAgent.setReward(reward)
            virtualAgent.goal = agent.getGoal()
            virtualAgent.goalMargins = agent.getGoalMargins()

            virtualAgent.setRl(
                partial(monteCarloSearch,
                        actions=getRandomActions(a, actions, maxDepth)))
            jobs[index] = virtualAgent

        pool = Pool(8)
        results = [pool.apply_async(job.run) for job in jobs]
        for result in results:
            action, score = result.get()
            qs[action].append(score)

        pool.close()
        pool.join()

        yield actions[np.argmax([np.average(qs[a]) for a in actions])]
        r, nextState, isTerminal = (yield)

        f = 1 / (nextState.lastUpdate - initialState.lastUpdate)
        # correct for deviations from desired freq.
        model.frequency = f

        agent.logger.info(f)

        yield
Example #10
    def map(self, fn, lazy=True, batched=False, num_workers=0):
        """
        Performs specific function on the dataset to transform and update every sample.

        Args:
            fn (callable): Transformation to be performed. It receives a single
                sample as argument if `batched` is False; otherwise it receives
                all examples at once.
            lazy (bool, optional): If True, transformations are delayed and
                performed on demand. Otherwise, all samples are transformed at
                once. Note that if `fn` is stochastic, `lazy` should be True or
                you will get the same result on every epoch. Defaults to True.
            batched (bool, optional): If True, the transformation takes all
                examples as input and returns a collection of transformed
                examples. Note that if set to True, the `lazy` option is
                ignored. Defaults to False.
            num_workers (int, optional): Number of processes for multiprocessing.
                If set to 0, multiprocessing is not used. Note that if set to a
                positive value, the `lazy` option is ignored. Defaults to 0.
        """

        assert num_workers >= 0, "num_workers should be a non-negative value"
        if num_workers > 1:
            shards = [
                self._shard(
                    num_shards=num_workers, index=index, contiguous=True)
                for index in range(num_workers)
            ]
            kwds_per_shard = [
                dict(
                    self=shards[rank], fn=fn, lazy=False, batched=batched)
                for rank in range(num_workers)
            ]
            pool = Pool(num_workers, initargs=(RLock(), ))
            results = [
                pool.apply_async(
                    self.__class__._map, kwds=kwds) for kwds in kwds_per_shard
            ]
            transformed_shards = [r.get() for r in results]
            pool.close()
            pool.join()
            self.new_data = []
            for i in range(num_workers):
                self.new_data += transformed_shards[i].new_data
            return self
        else:
            return self._map(fn, lazy=lazy, batched=batched)
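
A hedged usage sketch for this map; the dataset object and its 'text' field are assumptions. With num_workers > 1 the data is split into shards, the lazy flag is ignored, and each shard is mapped in its own worker process.

def lowercase(example):
    example['text'] = example['text'].lower()
    return example

dataset = dataset.map(lowercase, batched=False, num_workers=4)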
Example #11
class ProcessPoolExecutor(Executor):
    """Process Pool Executor"""
    def __init__(self):
        super(ProcessPoolExecutor, self).__init__()
        import os
        from multiprocess import Pool
        self.pool = Pool(os.cpu_count() or 1)

    def submit(self, func, *args, **kwargs):
        from concurrent.futures import Future
        fut = Future()
        self.tasks[fut] = self.pool.apply_async(func, args, kwargs,
                                                fut.set_result,
                                                fut.set_exception)
        fut.add_done_callback(self.tasks.pop)
        return fut

    def shutdown(self, wait=True):
        super(ProcessPoolExecutor, self).shutdown(wait)
        self.pool.terminate()
        self.pool.join()
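
A usage sketch for this executor; the Executor base class with its tasks dict is assumed to exist as implied by the snippet, and cube is a hypothetical worker function.

def cube(x):
    return x ** 3

if __name__ == '__main__':
    executor = ProcessPoolExecutor()
    future = executor.submit(cube, 7)
    print(future.result())   # 343, delivered via apply_async's callback -> fut.set_result
    executor.shutdown()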
Example #12
class ProcessPoolExecutor(Executor):
    """Process Pool Executor"""
    def __init__(self):
        super(ProcessPoolExecutor, self).__init__()
        import os
        from multiprocess import Pool
        self.pool = Pool(os.cpu_count() or 1)

    def submit(self, func, *args, **kwargs):
        from concurrent.futures import Future
        fut = Future()
        self.tasks[fut] = self.pool.apply_async(
            func, args, kwargs, fut.set_result, fut.set_exception
        )
        fut.add_done_callback(self.tasks.pop)
        return fut

    def shutdown(self, wait=True):
        super(ProcessPoolExecutor, self).shutdown(wait)
        self.pool.terminate()
        self.pool.join()
Example #13
    def filter(self, fn, num_workers=0):
        """
        Filters samples by the filter function and uses the filtered data to
        update this dataset.

        Args:
            fn (callable): A filter function that takes a sample as input and
                returns a boolean. Samples that return False would be discarded.
            num_workers(int, optional): Number of processes for multiprocessing. If 
                set to 0, it doesn't use multiprocessing. Defaults to `0`.
        """
        assert num_workers >= 0, "num_workers should be a non-negative value"
        if num_workers > 1:
            shards = [
                self._shard(
                    num_shards=num_workers, index=index, contiguous=True)
                for index in range(num_workers)
            ]
            kwds_per_shard = [
                dict(
                    self=shards[rank], fn=fn) for rank in range(num_workers)
            ]
            pool = Pool(num_workers, initargs=(RLock(), ))

            results = [
                pool.apply_async(
                    self.__class__._filter, kwds=kwds)
                for kwds in kwds_per_shard
            ]
            transformed_shards = [r.get() for r in results]

            pool.close()
            pool.join()
            self.new_data = []
            for i in range(num_workers):
                self.new_data += transformed_shards[i].new_data
            return self
        else:
            return self._filter(fn)
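
A correspondingly short sketch for filter, using the same hypothetical dataset object: drop empty samples, splitting the work over four shards.

def non_empty(example):
    return len(example['text']) > 0

dataset = dataset.filter(non_empty, num_workers=4)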
Example #14
    def prime_calculate(self):
        break_points = []  # List that will have start and stopping points
        for i in range(
                cores
        ):  # Creates start and stopping points based on length of range_finish
            break_points.append({
                "start":
                int(math.ceil(((self.maximum_prime + 1) + 0.0) / cores * i)),
                "stop":
                int(
                    math.ceil(
                        ((self.maximum_prime + 1) + 0.0) / cores * (i + 1)))
            })

        p = Pool(cores)  # Number of processes to create.
        for i in break_points:  # Cycles though the breakpoints list created above.
            a = p.apply_async(self.prime_calculator,
                              kwds=i,
                              args=tuple(),
                              callback=self.update_num
                              )  # This will start the separate processes.
        p.close()  # Prevents any more processes being started
        p.join()  # Waits for worker process to end
# multiprocess_pool.py
#!/usr/bin/env python
# _*_coding:utf-8_*_

# import multiprocess
from multiprocess import Pool
import os, time, random


def long_time_task(name):
    print('Run task %s (%s) ...' % (name, os.getpid()))
    start = time.time()
    time.sleep(random.random() * 3)
    end = time.time()
    print('Task %s runs %0.2f seconds.' % (name, (end - start)))


if __name__ == '__main__':
    print('Parent process %s.' % os.getpid())
    p = Pool(4)
    for i in range(5):
        p.apply_async(long_time_task, args=(i, ))
    print('Waiting for all subprocesses done...')
    p.close()
    p.join()
    print('All subprocesses done.')
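
A variant sketch of the same demo that collects each task's return value through apply_async's callback instead of only printing inside the worker; square and the result list are assumptions.

from multiprocess import Pool

def square(x):
    return x * x

if __name__ == '__main__':
    collected = []
    p = Pool(4)
    for i in range(5):
        p.apply_async(square, args=(i,), callback=collected.append)
    p.close()
    p.join()
    print(sorted(collected))   # [0, 1, 4, 9, 16]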
Example #16
class Scheduler(object):
    '''
    Scheduler
    '''
    def __init__(self, storage, threads):
        # Manager for concurrency
        self.manager = Manager()

        # System storage
        self.storage = storage

        # Queues
        self.high_access = self.manager.list([])
        self.normal_access = self.manager.list([])
        self._pool = Pool(processes=threads)

        # Operations
        self.operation_table = self.manager.dict()

    def add_operation(self, dataset_id, prio, map_operation, reduce_operation, return_address=None, write=False, read=True):
        '''
        Add operation to a dataset
        '''

        # Create operation object
        operation = Operation(map_operation, reduce_operation, return_address, write, read)

        # Add the operation to queue
        if dataset_id in self.operation_table:
            # Adding operation to the list of operations for the current dataset
            ## Creating temporary list, since it is not possible to append to a dictionary manager
            temperary_operations = self.operation_table[dataset_id]
            temperary_operations.append(operation)
            self.operation_table[dataset_id] = temperary_operations
        else:
            self.operation_table[dataset_id] = [operation]

        # Add data block to scheduler
        if prio == Priority.high:
            if dataset_id not in self.high_access:
                self.high_access.append(dataset_id)
                if dataset_id in self.normal_access:
                    self.normal_access.remove(dataset_id)

        elif prio == Priority.normal:
            if dataset_id not in self.normal_access and dataset_id not in self.high_access:
                self.normal_access.append(dataset_id)

    def _run_queue(self, dataset_id, debug=False):
        '''
        Run the queue of operations for a given dataset
        '''
        # Create data queue and a storage reading process
        if debug:
            print('~ Request data blocks from reading process')

        data_queue = self.manager.Queue()
        self.storage.read_data(dataset_id, data_queue)

        # Amount of operations
        operations = len(self.operation_table[dataset_id])

        # Amount of data-blocks
        data_blocks = self.storage.get_size(dataset_id)

        # Create a result list to each operation
        results = []
        for i in range(operations):
            results.append([])

        # Execute map-operation on the data queue
        for i in range(data_blocks):
            try:
                # Fetch data block from data queue
                data_block = data_queue.get(timeout=3)

                print('- Performing operations on block: ' + str(i) + ', dataset: ' + dataset_id)

                # Perform the operations on fetched data block
                op_index = 0
                for operation in self.operation_table[dataset_id]:
                    if debug:
                        print('~ Performing map operation (' + str(operation) + ') on block ' + str(i))

                    results[op_index].append(operation.map(data_block))
                    op_index += 1
            except:
                print('! Timed out waiting for data')

        # Execute the reduce-operation
        op_index = 0
        for operation in self.operation_table[dataset_id]:
            if debug:
                print('~ Performing reduce operation (' + str(operation) + ')')

            operation.reduce(results[op_index])
            op_index += 1

        # Clear the operation table for this block
        if debug:
            print('~ Clearing operations for '+ str(dataset_id))

        self.operation_table[dataset_id] = []

        # Remove the operation meta data for the dataset
        if operations > 0:
            if debug:
                print('~ Removing the dataset '+ str(dataset_id) + ' from operation table')

            del self.operation_table[dataset_id]

    def schedule(self, debug=False):
        '''
        Schedule the queued operations
        '''
        if debug:
            print('~ Initiating reading process')

        reading_process = Process(target=self.storage.reader)
        reading_process.start()

        while True:
            if debug:
                print()
                print('/ High priority queue is ' + str(self.high_access))
                print('/ Normal priority queue is ' + str(self.normal_access))
                print()

            # Pass the callable and its argument separately; calling
            # self._run_queue(...) inline would run it synchronously and hand
            # apply_async its return value instead of a function.
            if self.high_access:
                self._pool.apply_async(self._run_queue, (self.high_access.pop(0),))
            elif self.normal_access:
                self._pool.apply_async(self._run_queue, (self.normal_access.pop(0),))
            else:
                time.sleep(0.5)
Example #17
def test():
    print('cpuCount() = %d\n' % cpuCount())
    
    #
    # Create pool
    #
    
    PROCESSES = 4
    print('Creating pool with %d processes\n' % PROCESSES)
    pool = Pool(PROCESSES)    

    #
    # Tests
    #

    TASKS = [(mul, (i, 7)) for i in range(10)] + \
            [(plus, (i, 8)) for i in range(10)]

    results = [pool.apply_async(calculate, t) for t in TASKS]
    imap_it = pool.imap(calculatestar, TASKS)
    imap_unordered_it = pool.imap_unordered(calculatestar, TASKS)

    print('Ordered results using pool.apply_async():')
    for r in results:
        print('\t', r.get())
    print()

    print('Ordered results using pool.imap():')
    for x in imap_it:
        print('\t', x)
    print()

    print('Unordered results using pool.imap_unordered():')
    for x in imap_unordered_it:
        print('\t', x)
    print()

    print('Ordered results using pool.map() --- will block till complete:')
    for x in pool.map(calculatestar, TASKS):
        print('\t', x)
    print()

    #
    # Simple benchmarks
    #

    N = 100000
    print('def pow3(x): return x**3')
    
    t = time.time()
    A = list(map(pow3, xrange(N)))
    print('\tmap(pow3, xrange(%d)):\n\t\t%s seconds' % \
          (N, time.time() - t))
    
    t = time.time()
    B = pool.map(pow3, xrange(N))
    print('\tpool.map(pow3, xrange(%d)):\n\t\t%s seconds' % \
          (N, time.time() - t))

    t = time.time()
    C = list(pool.imap(pow3, xrange(N), chunksize=N//8))
    print('\tlist(pool.imap(pow3, xrange(%d), chunksize=%d)):\n\t\t%s' \
          ' seconds' % (N, N//8, time.time() - t))
    
    assert A == B == C, (len(A), len(B), len(C))
    print()
    
    L = [None] * 1000000
    print('def noop(x): pass')
    print('L = [None] * 1000000')
    
    t = time.time()
    A = list(map(noop, L))
    print('\tmap(noop, L):\n\t\t%s seconds' % \
          (time.time() - t))
    
    t = time.time()
    B = pool.map(noop, L)
    print('\tpool.map(noop, L):\n\t\t%s seconds' % \
          (time.time() - t))

    t = time.time()
    C = list(pool.imap(noop, L, chunksize=len(L)//8))
    print('\tlist(pool.imap(noop, L, chunksize=%d)):\n\t\t%s seconds' % \
          (len(L)//8, time.time() - t))

    assert A == B == C, (len(A), len(B), len(C))
    print()

    del A, B, C, L

    #
    # Test error handling
    #

    print('Testing error handling:')

    try:
        print(pool.apply(f, (5,)))
    except ZeroDivisionError:
        print('\tGot ZeroDivisionError as expected from pool.apply()')
    else:
        raise AssertionError('expected ZeroDivisionError')

    try:
        print(pool.map(f, range(10)))
    except ZeroDivisionError:
        print('\tGot ZeroDivisionError as expected from pool.map()')
    else:
        raise AssertionError('expected ZeroDivisionError')
            
    try:
        print(list(pool.imap(f, range(10))))
    except ZeroDivisionError:
        print('\tGot ZeroDivisionError as expected from list(pool.imap())')
    else:
        raise AssertionError('expected ZeroDivisionError')

    it = pool.imap(f, range(10))
    for i in range(10):
        try:
            x = it.next()
        except ZeroDivisionError:
            if i == 5:
                pass
        except StopIteration:
            break
        else:
            if i == 5:
                raise AssertionError('expected ZeroDivisionError')
            
    assert i == 9
    print('\tGot ZeroDivisionError as expected from IMapIterator.next()')
    print()
    
    #
    # Testing timeouts
    #
    
    print('Testing ApplyResult.get() with timeout:', end='')
    res = pool.apply_async(calculate, TASKS[0])
    while 1:
        sys.stdout.flush()
        try:
            sys.stdout.write('\n\t%s' % res.get(0.02))
            break
        except TimeoutError:
            sys.stdout.write('.')
    print()
    print()

    print('Testing IMapIterator.next() with timeout:', end='')
    it = pool.imap(calculatestar, TASKS)
    while 1:
        sys.stdout.flush()
        try:
            sys.stdout.write('\n\t%s' % it.next(0.02))
        except StopIteration:
            break
        except TimeoutError:
            sys.stdout.write('.')
    print()
    print()
            
    #
    # Testing callback
    #

    print('Testing callback:')
    
    A = []
    B = [56, 0, 1, 8, 27, 64, 125, 216, 343, 512, 729]
        
    r = pool.apply_async(mul, (7, 8), callback=A.append)
    r.wait()

    r = pool.map_async(pow3, range(10), callback=A.extend)
    r.wait()

    if A == B:
        print('\tcallbacks succeeded\n')
    else:
        print('\t*** callbacks failed\n\t\t%s != %s\n' % (A, B))
    
    #
    # Check there are no outstanding tasks
    #
    
    assert not pool._cache, 'cache = %r' % pool._cache

    #
    # Check close() methods
    #

    print('Testing close():')

    for worker in pool._pool:
        assert worker.is_alive()

    result = pool.apply_async(time.sleep, [0.5])
    pool.close()
    pool.join()

    assert result.get() is None

    for worker in pool._pool:
        assert not worker.is_alive()

    print('\tclose() succeeded\n')

    #
    # Check terminate() method
    #

    print('Testing terminate():')

    pool = Pool(2)
    ignore = pool.apply(pow3, [2])
    results = [pool.apply_async(time.sleep, [10]) for i in range(10)]
    pool.terminate()
    pool.join()

    for worker in pool._pool:
        assert not worker.is_alive()

    print('\tterminate() succeeded\n')

    #
    # Check garbage collection
    #

    print('Testing garbage collection:')

    pool = Pool(2)
    processes = pool._pool
    
    ignore = pool.apply(pow3, [2])
    results = [pool.apply_async(time.sleep, [10]) for i in range(10)]

    del results, pool

    time.sleep(0.2)
    
    for worker in processes:
        assert not worker.is_alive()

    print('\tgarbage collection succeeded\n')
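
test() refers to several helpers (cpuCount, mul, plus, f, pow3, noop, calculate, calculatestar) that are defined elsewhere in the original module, and its benchmark section assumes Python 2's xrange. The definitions below are a sketch reconstructed from the classic multiprocessing pool example this test mirrors; treat them as assumptions rather than the exact originals.

import time, random
from multiprocess import cpu_count, current_process

cpuCount = cpu_count   # the test calls the old camelCase name

def calculate(func, args):
    result = func(*args)
    return '%s says that %s%s = %s' % (
        current_process().name, func.__name__, args, result)

def calculatestar(args):
    return calculate(*args)

def mul(a, b):
    time.sleep(0.5 * random.random())
    return a * b

def plus(a, b):
    time.sleep(0.5 * random.random())
    return a + b

def f(x):
    return 1.0 / (x - 5.0)   # raises ZeroDivisionError at x == 5, as the test expects

def pow3(x):
    return x ** 3

def noop(x):
    pass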
Example #18
class GroupCheckerGui(BaseWidget):
    def __init__(self):
        super(GroupCheckerGui, self).__init__('Group Checker')

        self._group_name = ControlText('Group Name', CONFIG['group_name'])
        self._group_name.enabled = False
        self._allowed_tags = UnicodeControlList(
            'Allowed Tags',
            plusFunction=self.__add_tag_action,
            minusFunction=self.__remove_tag_action)
        self.allowed_tags = GuiList(
            CONFIG['white_filters']['SubstringFilter']['substrings'],
            self._allowed_tags)

        self._allowed_ids = ControlList('Allowed Ids',
                                        plusFunction=self.__add_id_action,
                                        minusFunction=self.__remove_id_action)
        self.allowed_ids = GuiList(
            CONFIG['white_filters']['SignerFilter']['ids'], self._allowed_ids)

        self._bad_posts = ControlCheckBoxList('Bad posts')
        self._bad_posts._form.listWidget.itemDoubleClicked.connect(
            self.__show_link_action)

        self._remove_button = ControlButton('Remove')
        self._remove_button.value = self.__remove_action

        self._show_button = ControlButton('Show bad posts')
        self._show_button.value = self.__show_bad_post_action

        self.pool = Pool(processes=1)
        self.bad_posts = []

        self._formset = [('', '_group_name', ''),
                         ('', '_allowed_tags', '_allowed_ids', ''), '',
                         ('', '_bad_posts', ''),
                         ('', '_remove_button', '_show_button', ''), '']

    def __add_tag_action(self):
        win = PopUpGetText('tag', self.allowed_tags)
        win.show()

    def __remove_tag_action(self):
        self.allowed_tags.remove(self._allowed_tags.mouseSelectedRowIndex)

    def __add_id_action(self):
        win = PopUpGetText('id', self.allowed_ids)
        win.show()

    def __remove_id_action(self):
        self.allowed_ids.remove(self._allowed_ids.mouseSelectedRowIndex)

    def __show_bad_post_action(self):
        def callback(posts):
            self.bad_posts = posts
            self._bad_posts.value = [
                (GroupBot.get_link_from_post(post, CONFIG['group_name']), True)
                for post in posts
            ]
            self._show_button.enabled = True

        def run_bot():
            bot = create_bot_from_config()
            return bot.get_bad_posts()

        self._show_button.enabled = False
        self.pool.apply_async(run_bot, callback=callback)

    def __show_link_action(self, link):
        webbrowser.open(link.text())

    def __remove_action(self):
        checked_posts = [
            self.bad_posts[idx] for idx in self._bad_posts.checkedIndexes
        ]
        bot = create_bot_from_config()
        bot.remove_posts(checked_posts)
Example #19
def main():
    start_time = time.time()
    parser = get_args()
    if not sys.argv[1:]:
        parser.print_help(file=sys.stderr)
        sys.exit(2)

    # Parse arguments.
    args = parser.parse_args()
    if args.whitelist:
        whitelist = preprocessing.parse_whitelist_csv(
            args.whitelist, args.cb_last - args.cb_first + 1)
    else:
        whitelist = None

    # Load TAGs/ABs.
    ab_map = preprocessing.parse_tags_csv(args.tags)
    ab_map = preprocessing.check_tags(ab_map, args.max_error)
    # Get reads length. So far, there is no validation for Read2.
    read1_length = preprocessing.get_read_length(args.read1_path)
    read2_length = preprocessing.get_read_length(args.read2_path)
    # Check Read1 length against CELL and UMI barcodes length.
    (barcode_slice, umi_slice,
     barcode_umi_length) = preprocessing.check_barcodes_lengths(
         read1_length, args.cb_first, args.cb_last, args.umi_first,
         args.umi_last)

    if args.first_n:
        n_lines = args.first_n * 4
    else:
        n_lines = preprocessing.get_n_lines(args.read1_path)
    n_reads = int(n_lines / 4)
    n_threads = args.n_threads

    print('Started mapping')
    #Run with one process
    if n_threads <= 1 or n_reads < 1000001:
        print('CITE-seq-Count is running with one core.')
        (final_results, merged_no_match) = processing.map_reads(
            read1_path=args.read1_path,
            read2_path=args.read2_path,
            tags=ab_map,
            barcode_slice=barcode_slice,
            umi_slice=umi_slice,
            indexes=[0, n_reads],
            whitelist=whitelist,
            debug=args.debug,
            start_trim=args.start_trim,
            maximum_distance=args.max_error)
        print('Mapping done')
        umis_per_cell = Counter()
        reads_per_cell = Counter()
        for cell_barcode, counts in final_results.items():
            umis_per_cell[cell_barcode] = sum(
                [len(counts[UMI]) for UMI in counts if UMI != 'unmapped'])
            reads_per_cell[cell_barcode] = sum([
                sum(counts[UMI].values()) for UMI in counts
                if UMI != 'unmapped'
            ])
    else:
        # Run with multiple processes
        print('CITE-seq-Count is running with {} cores.'.format(n_threads))
        p = Pool(processes=n_threads)
        chunk_indexes = preprocessing.chunk_reads(n_reads, n_threads)
        parallel_results = []

        for indexes in chunk_indexes:
            p.apply_async(processing.map_reads,
                          args=(args.read1_path, args.read2_path, ab_map,
                                barcode_slice, umi_slice, indexes, whitelist,
                                args.debug, args.start_trim, args.max_error),
                          callback=parallel_results.append,
                          error_callback=lambda e: print(e, file=sys.stderr))
        p.close()
        p.join()
        print('Mapping done')
        print('Merging results')
        (final_results, umis_per_cell, reads_per_cell,
         merged_no_match) = processing.merge_results(
             parallel_results=parallel_results)
        del (parallel_results)

    # Correct cell barcodes
    (final_results, umis_per_cell, bcs_corrected) = processing.correct_cells(
        final_results=final_results,
        umis_per_cell=umis_per_cell,
        expected_cells=args.expected_cells,
        collapsing_threshold=args.bc_threshold)

    # Correct umi barcodes
    (final_results, umis_corrected) = processing.correct_umis(
        final_results=final_results, collapsing_threshold=args.umi_threshold)

    ordered_tags_map = OrderedDict()
    for i, tag in enumerate(ab_map.values()):
        ordered_tags_map[tag] = i
    ordered_tags_map['unmapped'] = i + 1

    # Sort cells by number of mapped umis
    if not whitelist:
        top_cells_tuple = umis_per_cell.most_common(args.expected_cells)
        top_cells = set([pair[0] for pair in top_cells_tuple])
    else:
        top_cells = whitelist
        # Add potential missing cell barcodes.
        for missing_cell in whitelist:
            if missing_cell in final_results:
                continue
            else:
                final_results[missing_cell] = dict()
                for TAG in ordered_tags_map:
                    final_results[missing_cell][TAG] = 0
                top_cells.add(missing_cell)

    (umi_results_matrix,
     read_results_matrix) = processing.generate_sparse_matrices(
         final_results=final_results,
         ordered_tags_map=ordered_tags_map,
         top_cells=top_cells)
    io.write_to_files(sparse_matrix=umi_results_matrix,
                      top_cells=top_cells,
                      ordered_tags_map=ordered_tags_map,
                      data_type='umi',
                      outfolder=args.outfolder)
    io.write_to_files(sparse_matrix=read_results_matrix,
                      top_cells=top_cells,
                      ordered_tags_map=ordered_tags_map,
                      data_type='read',
                      outfolder=args.outfolder)

    top_unmapped = merged_no_match.most_common(args.unknowns_top)
    with open(os.path.join(args.outfolder, args.unmapped_file),
              'w') as unknown_file:
        unknown_file.write('tag,count\n')
        for element in top_unmapped:
            unknown_file.write('{},{}\n'.format(element[0], element[1]))
    create_report(n_reads=n_reads,
                  reads_per_cell=reads_per_cell,
                  no_match=merged_no_match,
                  version=version,
                  start_time=start_time,
                  ordered_tags_map=ordered_tags_map,
                  umis_corrected=umis_corrected,
                  bcs_corrected=bcs_corrected,
                  args=args)
    if args.dense:
        print('Writing dense format output')
        io.write_dense(sparse_matrix=umi_results_matrix,
                       index=list(ordered_tags_map.keys()),
                       columns=top_cells,
                       file_path=os.path.join(args.outfolder,
                                              'dense_umis.tsv'))
Example #20
def capacity_estimation(n,
                        threshold=0,
                        runs=1000,
                        num_averages=10,
                        num_updates=1,
                        asynchronous=False,
                        distinct_memories=True,
                        parallel=True,
                        processes=None,
                        use_SER=True):
    """
    Estimates the number m_max of storeable patterns for all n given as array, 
    such that all m smaller or equal to m_max have an error rate which is exactly 0.
    Repeats the estimation of m_max with run samples of the error rate and repeats this num_averages times
    Returns mean of m_max and standard deviation as np.array 
    """
    m_max = -1 * np.zeros((len(n), num_averages))

    # loop over all n:
    for i in range(len(n)):
        print(n[i])

        # Parallel method
        if parallel:

            # Repeat estimation of at_least_one_fail runs times for averaging in a parallelized fashion
            pool = Pool(processes=processes)
            results = [
                pool.apply_async(estimate_m_max,
                                 args=(n[i], runs, asynchronous, num_updates,
                                       distinct_memories, threshold))
                for _ in range(num_averages)
            ]
            output = [process.get() for process in results]
            m_max[i, :] = np.array(output, dtype=int)
            pool.terminate()

        # Standard method
        else:
            for k in range(num_averages):
                m = 0
                error_rate = 0
                while error_rate <= threshold:
                    m += 1
                    if use_SER:
                        error_rate, _ = test_error_rate(
                            n[i],
                            m,
                            runs=runs,
                            asynchronous=asynchronous,
                            num_updates=num_updates,
                            distinct_memories=distinct_memories,
                            parallel=parallel,
                            processes=processes)
                    else:
                        error_rate, _, _, _ = test_error_rate_attraction(
                            n[i],
                            m,
                            0,
                            N_test_samples_around_memory=1,
                            averages=1,
                            asynchronous=asynchronous,
                            num_updates=num_updates,
                            distinct_memories=True,
                            parallel=parallel,
                            processes=processes)
                m_max[i, k] = m - 1

    return np.mean(m_max, axis=1), np.std(m_max, axis=1)
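
A hedged call sketch; the array of sizes and the process count are arbitrary, all other names are as defined in this example.

import numpy as np

mean_m_max, std_m_max = capacity_estimation(np.array([20, 40, 60]),
                                            runs=200, num_averages=5,
                                            parallel=True, processes=4)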
Example #21
if __name__ == "__main__":
    # exp_list = [exp01, exp01]
    test_prefix_dir = "../../../../../../work/agents/Compare/AgentStepAsCkpt"
    # exp_ids = ['exp01', 'exp09', 'exp10', 'exp19']
    exp_ids = ['exp12']
    exp_infos = concat_names(test_prefix_dir, exp_ids, num=11)

    to_do_exp_infos = {}
    for i in exp_infos:
        to_do_exp_infos[i] = []

    for i in exp_infos:
        exps = exp_infos[i]
        for exp in exps:
            if os.path.exists(exp) and not os.path.exists(exp + "/logging_0"):
                to_do_exp_infos[i].append(exp)

    to_do_exp_list = to_do_exp_infos.values()
    lunzhuan_exp_list = lunzhun(to_do_exp_list)
    from pprint import pprint
    pprint(lunzhuan_exp_list)

    pool = Pool(3)
    # pool.map(run_exp, lunzhuan_exp_list)
    for para in lunzhuan_exp_list:
        pool.apply_async(run_exp, (para, ))
    print "What happened"
    pool.close()
    pool.join()
Example #22
def test_error_rate_attraction(n: int,
                               m: int,
                               noise_rate: float,
                               N_test_samples_around_memory=10,
                               averages=1000,
                               num_updates=1,
                               asynchronous=False,
                               distinct_memories=True,
                               parallel=True,
                               processes=None):
    """
    Estimates message and bit error rates for a given number of test samples in the vicinity of all memories.
    The test samples are produced by flipping every bit with probability noise_rate, and the output after a number
    of updates is compared to the corresponding memory bit string.
    The whole procedure is repeated `averages` times, with new stored patterns generated uniformly at random each time.
    
    Args:   - n: input dimension
            - m: number of stored patterns
            - noise_rate: noise applied to memories to produce test samples
            - N_test_samples_around_memory: number of noisy test samples produced around each memory
            - averages: number of iterations for averaging, i.e. how often the whole procedure is repeated
            - num_updates: number of synchronous updates, so how often sign(W...) is applied
            - synchronous updates if True, else asynchronous

    Returns:  - mean of MessageErrorRate (MER) estimated for each repetition
            - mean of BitErrorRate (BER) estimated for each repetition
            - standard deviation of MER
            - standard deviation of BER
    """
    BitErrorRates = np.zeros(averages)
    MessageErrorRates = np.zeros(averages)

    # Repeat error rate estimation averages times, each time picking new, randomly chosen memories
    # Parallel method
    if parallel:

        # Repeat estimation of at_least_one_fail runs times for averaging in a parallelized fashion
        pool = Pool(processes=processes)
        results = [
            pool.apply_async(estimate_error_rates_in_parallel,
                             args=(n, m, distinct_memories,
                                   N_test_samples_around_memory, noise_rate,
                                   asynchronous, num_updates))
            for _ in range(averages)
        ]
        outputs = np.array([process.get() for process in results])
        MessageErrorRates = outputs[:, 0]
        BitErrorRates = outputs[:, 1]
        pool.terminate()

    # standard method
    else:
        for k in range(averages):
            W, x = new_weights(n, m, distinct_memories=distinct_memories)
            distances = np.zeros((m, N_test_samples_around_memory))

            for p in range(m):
                # print('memory : {}'.format(x[p]))
                for i_test in range(N_test_samples_around_memory):
                    # create a test sample by applying uniform noise to the corresponding memory
                    test_string = uniform_noise(x[p],
                                                noise_rate,
                                                binary01=False)
                    # update the test sample
                    if asynchronous:
                        updated = asynchronous_update(W,
                                                      test_string,
                                                      num_updates=num_updates)
                    else:
                        updated = synchronous_update(W,
                                                     test_string,
                                                     num_updates=num_updates)
                    # print('test: {}, upd: {}'.format(test_string,updated))
                    # Calculate the distance to the corresponding memory
                    distances[p, i_test] = hamming(updated, x[p]) * n

            # store the bit error rates
            BitErrorRates[k] = np.mean(distances / n)

            # store message error rates: if any bit is wrong, whole message is wrong!
            # Note difference to strict error definition in test_error_rate:
            # Here: For noise_rate=0, all memories are considered. If one out of all memories cannot be recalled exactly,
            #       the error rate of this round is 1/m
            # Strict:   If an error occurs anywhere, assign failure to all. If one out of all memories cannot be recalled
            #           exactly, the error rate of this round is 1!
            # Use np.any(distances>0) to recover the strict definition!
            MessageErrorRates[k] = np.mean(distances > 0)

    return np.mean(MessageErrorRates), np.mean(BitErrorRates), np.std(
        MessageErrorRates), np.std(BitErrorRates)
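
A hedged call sketch: message and bit error rates for a 32-neuron network storing 5 patterns, with 10% bit-flip noise on the probes (parameter values are arbitrary).

mer, ber, mer_std, ber_std = test_error_rate_attraction(
    32, 5, noise_rate=0.1, N_test_samples_around_memory=10,
    averages=100, parallel=True, processes=4)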
Example #23
def test_error_rate_quantum_parallel(n: int,
                                     m: int,
                                     runs=1000,
                                     use_QI_backend=False,
                                     api=None,
                                     num_updates=1,
                                     gate_fusion=False,
                                     separate_outputs=True,
                                     output_std=True,
                                     output_probabilities=True,
                                     processes=None,
                                     logfile_name='./log.txt',
                                     errorlog_name='./errorlog.txt',
                                     num_runs=1,
                                     verbose=False,
                                     shots=0):
    """
    Estimate average error rate, when requiring that all stored patterns must be retrieved perfectly.
    In a parallelized fashion, new stored patterns are generated uniformly at random. 
    To run the HNN, a quantum version is applied.
    The Hebbian weighting matrix is calculated and used to estimate the gate parameters of the quantum circuit.
    
    Note: Do not use this function if OS is Windows! (Typically causes problems) 

    Args:   - n: input dimension
            - m: number of stored patterns
            - runs: number of iterations for average
            - separate_outputs: If true, simulation is done for each output qubit separately, 
                thus only n+1 qubits need to be simulated. Use this for local simulations.
                Else, the whole 2n-qubit circuit is simulated. Use this for Quantum Inspire, as long as n<14

            If use_QI_backend=True:
            - api to connect to QuantumInspire
            - num_updates: number of synchronous updates, so how often the updating scheme is applied

            If use_QI_backend=False:
            - gate_fusion parameter for ProjectQ Simulator, may or may not speed up simulation
            - shots = number of experiments. If 0, true probabilities are considered, 
            else experiments are simulated using np.random.choice, based on true probabilities.

    Returns:  - Average rate of how often at least one stored pattern cannot be retrieved after num_updates updates
            - Standard deviation if output_std = True
            - runs x m-array with probabilities of all estimated most likely outcomes (only if output_probabilities=True)
    """
    # Seems to not work under Windows, thus abort then
    #%%
    assert (
        platform != 'win32'
    ), 'Parallelization with multiprocess package does not work with Windows'

    #at_least_one_fail = np.zeros(runs,dtype=bool)
    probabilities = -np.zeros((runs, m))

    # Repeat procedure runs times for averaging in a parallelized fashion
    pool = Pool(processes=processes)
    results = [
        pool.apply_async(parallel_method,
                         args=(n, m, separate_outputs, verbose, api,
                               logfile_name, errorlog_name, use_QI_backend,
                               gate_fusion, num_runs, new_weights, shots))
        for _ in range(runs)
    ]
    output = [process.get() for process in results]
    output = np.array(output)

    # Format the outputs and store in proper arrays
    at_least_one_fail = np.array(output[:, 0], dtype=bool)
    for k in range(runs):
        for p in range(m):
            probabilities[k, p] = output[k, 1][p]

    # go to next line in log file
    if verbose:
        with open(logfile_name, 'a') as f:
            f.write('\n')
    # Return results according the demanded outputs
    if output_std:
        if output_probabilities:
            return np.average(at_least_one_fail), np.std(
                at_least_one_fail), probabilities
        else:
            return np.average(at_least_one_fail), np.std(at_least_one_fail)
    else:
        if output_probabilities:
            return np.average(at_least_one_fail), probabilities
        else:
            return np.average(at_least_one_fail)
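
A hedged call sketch for the quantum variant, using the local simulator backend; the parameter values are arbitrary.

fail_rate, fail_std, probabilities = test_error_rate_quantum_parallel(
    4, 2, runs=50, use_QI_backend=False, separate_outputs=True,
    processes=2, shots=0)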
Example #24
    def sample_chains(self,
                      n_sample,
                      init_states,
                      trace_funcs,
                      n_process=1,
                      **kwargs):
        """Sample one or more Markov chains from given initial states.

        Performs a specified number of chain iterations (each of which may be
        composed of multiple individual Markov transitions), recording the
        outputs of functions of the sampled chain state after each iteration.
        The chains may be run in parallel across multiple independent processes
        or sequentially. In all cases all chains use independent random draws.

        Args:
            n_sample (int): Number of samples (iterations) to draw per chain.
            init_states (Iterable[ChainState] or \
                         Iterable[Dict[str, object]]): Initial chain states.
                Each entry can be either a `ChainState` object or a dictionary
                with entries specifying initial values for all state variables
                used by chain transition `sample` methods.
            trace_funcs (Iterable[Callable[[ChainState],\
                                  Dict[str, array or float]]]): List of
                functions which compute the variables to be recorded at each
                chain iteration, with each trace function being passed the
                current state and returning a dictionary of scalar or array
                values corresponding to the variable(s) to be stored. The keys
                in the returned dictionaries are used to index the trace arrays
                in the returned traces dictionary. If a key appears in multiple
                dictionaries only the value corresponding to the last trace
                function to return that key will be stored.
            n_process (int or None): Number of parallel processes to run chains
                over. If set to one then chains will be run sequentially,
                otherwise a `multiprocessing.Pool` object will be used to
                dynamically assign the chains across multiple processes. If
                set to `None` then the number of processes will default to the
                output of `os.cpu_count()`.

        Kwargs:
            memmap_enabled (bool): Whether to memory-map arrays used to store
                chain data to files on disk to avoid excessive system memory
                usage for long chains and/or large chain states. The chain data
                is written to `.npy` files in the directory specified by
                `memmap_path` (or a temporary directory if not provided). These
                files persist after the termination of the function so should
                be manually deleted when no longer required. Default is for
                memory mapping to be disabled.
            memmap_path (str): Path to directory to write memory-mapped chain
                data to. If not provided, a temporary directory will be created
                and the chain data written to files there.
            monitor_stats (Iterable[Tuple[str, str]]): List of tuples of string
                key pairs, with first entry the key of a Markov transition in
                the `transitions` dict passed to the the `__init__` method and
                the second entry the key of a chain statistic that will be
                returned in the `chain_stats` dictionary. The mean over samples
                computed so far of the chain statistics associated with any
                valid key-pairs will be monitored during sampling  by printing
                as postfix to progress bar (if `tqdm` is installed).

        Returns:
            final_states (List[ChainState]): States of chains after final
                iteration. May be used to resume sampling a chain by passing as
                the initial states to a new `sample_chains` call.
            traces (Dict[str, List[array]]): Dictionary of chain trace arrays.
                Values in dictionary are list of arrays of variables outputted
                by trace functions in `trace_funcs` with each array in the list
                corresponding to a single chain and the leading dimension of
                each array corresponding to the sampling (draw) index. The key
                for each value is the corresponding key in the dictionary
                returned by the trace function which computed the traced value.
            chain_stats (Dict[str, Dict[str, List[array]]]): Dictionary of
                chain transition statistic dictionaries. Values in outer
                dictionary are dictionaries of statistics for each chain
                transition, keyed by the string key for the transition. The
                values in each inner transition dictionary are lists of arrays
                of chain statistic values with each array in the list
                corresponding to a single chain and the leading dimension of
                each array corresponding to the sampling (draw) index. The key
                for each value is a string description of the corresponding
                integration transition statistic.
        """
        n_chain = len(init_states)
        kwargs = self.__preprocess_kwargs(kwargs)
        if RANDOMGEN_AVAILABLE:
            seed = self.rng.randint(2**64, dtype='uint64')
            rngs = [
                randomgen.Xorshift1024(seed).jump(i).generator
                for i in range(n_chain)
            ]
        else:
            seeds = (self.rng.choice(2**16, n_chain, False) * 2**16 +
                     self.rng.choice(2**16, n_chain, False))
            rngs = [np.random.RandomState(seed) for seed in seeds]
        chain_outputs = []
        shared_kwargs_list = [{
            'rng': rng,
            'n_sample': n_sample,
            'init_state': init_state,
            'trace_funcs': trace_funcs,
            'chain_index': c,
            **kwargs
        } for c, (rng, init_state) in enumerate(zip(rngs, init_states))]
        if n_process == 1:
            # Using single process therefore run chains sequentially
            for c, shared_kwargs in enumerate(shared_kwargs_list):
                final_state, traces, stats, sample_index = self._sample_chain(
                    **shared_kwargs, parallel_chains=False)
                chain_outputs.append(
                    (final_state, traces, stats, sample_index))
                if sample_index != n_sample:
                    logger.error(
                        f'Sampling manually interrupted at chain {c} iteration'
                        f' {sample_index}. Arrays containing chain traces'
                        f' and statistics computed before interruption will'
                        f' be returned.')
                    break
        else:
            # Run chains in parallel using a multiprocess(ing).Pool
            # Child processes made to ignore SIGINT signals to allow handling
            # of KeyboardInterrupts in parent process
            n_completed = 0
            pool = Pool(n_process, _ignore_sigint_initialiser)
            try:
                results = [
                    pool.apply_async(self._sample_chain,
                                     kwds=dict(**shared_kwargs,
                                               parallel_chains=True))
                    for shared_kwargs in shared_kwargs_list
                ]
                for result in results:
                    chain_outputs.append(result.get())
                    n_completed += 1
            except KeyboardInterrupt:
                # Close any still running processes
                pool.terminate()
                pool.join()
                err_message = 'Sampling manually interrupted.'
                if n_completed > 0:
                    err_message += (
                        f' Data for {n_completed} completed chains will be '
                        f'returned.')
                if kwargs['memmap_enabled']:
                    err_message += (
                        f' All data recorded so far including in progress '
                        f'chains is available in directory '
                        f'{kwargs["memmap_path"]}.')
                logger.error(err_message)
        # When running parallel jobs with memory-mapping enabled, data arrays
        # are returned by the processes as file paths to array memory-maps,
        # therefore load the memory-map objects from file before returning results
        load_memmaps = kwargs['memmap_enabled'] and n_process > 1
        return self._collate_chain_outputs(n_sample, chain_outputs,
                                           load_memmaps)
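
A hedged usage sketch: the sampler object, its transitions, and the state attribute names (pos, initial_positions) are assumptions; four chains of 1000 iterations are run in parallel.

def trace_position(state):
    return {'pos': state.pos}

final_states, traces, chain_stats = sampler.sample_chains(
    n_sample=1000,
    init_states=[{'pos': q0} for q0 in initial_positions],
    trace_funcs=[trace_position],
    n_process=4)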
Example #25
    pprint(lunzhuan_exp_list)

    for para in lunzhuan_exp_list:
        if 'exp01' in para:
            exp_list.append(exp01)
        elif 'exp09' in para:
            exp_list.append(exp09)
        elif 'exp10' in para:
            exp_list.append(exp10)
        elif 'exp19' in para:
            exp_list.append(exp19)
        elif 'exp21' in para:
            exp_list.append(exp21)
        # elif 'exp22' in para:
        #     exp_list.append(exp22)
        # elif 'exp23' in para:
        #     exp_list.append(exp23)
        # elif 'exp24' in para:
        #     exp_list.append(exp24)
        else:
            print "Wrong exp id"
    print "exp_list: ", exp_list
    pool = Pool(4)
    for exp, para in zip(exp_list, lunzhuan_exp_list):
        pool.apply_async(exp, (para, ))
    # pool.map(exp_list, lunzhuan_exp_list)
    pool.close()
    pool.join()
    #

    # print concat_names(test_prefix_dir, exp_infos)
Example #26
def main():
    # Create logger and stream handler
    logger = logging.getLogger("cite_seq_count")
    logger.setLevel(logging.CRITICAL)
    ch = logging.StreamHandler()
    ch.setLevel(logging.CRITICAL)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    start_time = time.time()
    parser = get_args()
    if not sys.argv[1:]:
        parser.print_help(file=sys.stderr)
        sys.exit(2)

    # Parse arguments.
    args = parser.parse_args()
    if args.whitelist:
        print("Loading whitelist")
        (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv(
            filename=args.whitelist,
            barcode_length=args.cb_last - args.cb_first + 1,
            collapsing_threshold=args.bc_threshold,
        )
    else:
        whitelist = False

    # Load TAGs/ABs.
    ab_map = preprocessing.parse_tags_csv(args.tags)
    ab_map = preprocessing.check_tags(ab_map, args.max_error)

    # Identify input file(s)
    read1_paths, read2_paths = preprocessing.get_read_paths(
        args.read1_path, args.read2_path
    )

    # Preprocessing and processing occur in separate loops so the program can
    # fail early if one of the inputs is not valid.
    read1_lengths = []
    read2_lengths = []
    for read1_path, read2_path in zip(read1_paths, read2_paths):
        # Get read lengths. So far, there is no validation for Read2.
        read1_lengths.append(preprocessing.get_read_length(read1_path))
        read2_lengths.append(preprocessing.get_read_length(read2_path))
        # Check Read1 length against CELL and UMI barcodes length.
        (
            barcode_slice,
            umi_slice,
            barcode_umi_length,
        ) = preprocessing.check_barcodes_lengths(
            read1_lengths[-1],
            args.cb_first,
            args.cb_last,
            args.umi_first,
            args.umi_last,
        )
    # Ensure all files have the same input length
    # if len(set(read1_lengths)) != 1:
    #    sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting')

    # Initialize the counts dicts that will be generated from each input fastq pair
    final_results = defaultdict(lambda: defaultdict(Counter))
    umis_per_cell = Counter()
    reads_per_cell = Counter()
    merged_no_match = Counter()
    number_of_samples = len(read1_paths)
    n_reads = 0

    # Print a statement if multiple files are run.
    if number_of_samples != 1:
        print("Detected {} files to run on.".format(number_of_samples))

    for read1_path, read2_path in zip(read1_paths, read2_paths):
        if args.first_n:
            n_lines = (args.first_n * 4) / number_of_samples
        else:
            n_lines = preprocessing.get_n_lines(read1_path)
        n_reads += int(n_lines / 4)
        n_threads = args.n_threads
        print("Started mapping")
        print("Processing {:,} reads".format(n_reads))
        # Run with one process
        if n_threads <= 1 or n_reads < 1000001:
            print("CITE-seq-Count is running with one core.")
            (_final_results, _merged_no_match) = processing.map_reads(
                read1_path=read1_path,
                read2_path=read2_path,
                tags=ab_map,
                barcode_slice=barcode_slice,
                umi_slice=umi_slice,
                indexes=[0, n_reads],
                whitelist=whitelist,
                debug=args.debug,
                start_trim=args.start_trim,
                maximum_distance=args.max_error,
                sliding_window=args.sliding_window,
            )
            print("Mapping done")
            _umis_per_cell = Counter()
            _reads_per_cell = Counter()
            for cell_barcode, counts in _final_results.items():
                _umis_per_cell[cell_barcode] = sum([len(counts[UMI]) for UMI in counts])
                _reads_per_cell[cell_barcode] = sum(
                    [sum(counts[UMI].values()) for UMI in counts]
                )
        else:
            # Run with multiple processes
            print("CITE-seq-Count is running with {} cores.".format(n_threads))
            p = Pool(processes=n_threads)
            chunk_indexes = preprocessing.chunk_reads(n_reads, n_threads)
            parallel_results = []

            for indexes in chunk_indexes:
                p.apply_async(
                    processing.map_reads,
                    args=(
                        read1_path,
                        read2_path,
                        ab_map,
                        barcode_slice,
                        umi_slice,
                        indexes,
                        whitelist,
                        args.debug,
                        args.start_trim,
                        args.max_error,
                        args.sliding_window,
                    ),
                    callback=parallel_results.append,
                    # error_callback must be a callable; write the exception to stderr
                    error_callback=lambda e: print(e, file=sys.stderr),
                )
            p.close()
            p.join()
            print("Mapping done")
            print("Merging results")

            (
                _final_results,
                _umis_per_cell,
                _reads_per_cell,
                _merged_no_match,
            ) = processing.merge_results(parallel_results=parallel_results)
            del parallel_results

        # Update the overall counts dicts
        umis_per_cell.update(_umis_per_cell)
        reads_per_cell.update(_reads_per_cell)
        merged_no_match.update(_merged_no_match)
        for cell_barcode in _final_results:
            for tag in _final_results[cell_barcode]:
                if tag in final_results[cell_barcode]:
                    # Counter + Counter = Counter
                    final_results[cell_barcode][tag] += _final_results[cell_barcode][
                        tag
                    ]
                else:
                    # Explicitly save the counter to that tag
                    final_results[cell_barcode][tag] = _final_results[cell_barcode][tag]
    ordered_tags_map = OrderedDict()
    for i, tag in enumerate(ab_map.values()):
        ordered_tags_map[tag] = i
    ordered_tags_map["unmapped"] = i + 1

    # Correct cell barcodes
    if args.bc_threshold > 0:
        if len(umis_per_cell) <= args.expected_cells:
            print(
                "Number of expected cells, {}, is higher "
                "than number of cells found {}.\nNot performing"
                "cell barcode correction"
                "".format(args.expected_cells, len(umis_per_cell))
            )
            bcs_corrected = 0
        else:
            print("Correcting cell barcodes")
            if not whitelist:
                (
                    final_results,
                    umis_per_cell,
                    bcs_corrected,
                ) = processing.correct_cells(
                    final_results=final_results,
                    reads_per_cell=reads_per_cell,
                    umis_per_cell=umis_per_cell,
                    expected_cells=args.expected_cells,
                    collapsing_threshold=args.bc_threshold,
                    ab_map=ordered_tags_map,
                )
            else:
                (
                    final_results,
                    umis_per_cell,
                    bcs_corrected,
                ) = processing.correct_cells_whitelist(
                    final_results=final_results,
                    umis_per_cell=umis_per_cell,
                    whitelist=whitelist,
                    collapsing_threshold=args.bc_threshold,
                    ab_map=ordered_tags_map,
                )
    else:
        bcs_corrected = 0

    # If given, use whitelist for top cells
    if whitelist:
        top_cells = whitelist
        # Add potential missing cell barcodes.
        for missing_cell in whitelist:
            if missing_cell in final_results:
                continue
            else:
                final_results[missing_cell] = dict()
                for TAG in ordered_tags_map:
                    final_results[missing_cell][TAG] = Counter()
                top_cells.add(missing_cell)
    else:
        # Select top cells based on total umis per cell
        top_cells_tuple = umis_per_cell.most_common(args.expected_cells)
        top_cells = set([pair[0] for pair in top_cells_tuple])

    # UMI correction

    if args.no_umi_correction:
        # Don't correct
        umis_corrected = 0
        aberrant_cells = set()
    else:
        # Correct UMIS
        (final_results, umis_corrected, aberrant_cells) = processing.correct_umis(
            final_results=final_results,
            collapsing_threshold=args.umi_threshold,
            top_cells=top_cells,
            max_umis=20000,
        )

    # Remove aberrant cells from the top cells
    for cell_barcode in aberrant_cells:
        top_cells.remove(cell_barcode)

    # Create sparse aberrant cells matrix
    (umi_aberrant_matrix, read_aberrant_matrix) = processing.generate_sparse_matrices(
        final_results=final_results,
        ordered_tags_map=ordered_tags_map,
        top_cells=aberrant_cells,
    )

    # Write uncorrected cells to dense output
    io.write_dense(
        sparse_matrix=umi_aberrant_matrix,
        index=list(ordered_tags_map.keys()),
        columns=aberrant_cells,
        outfolder=os.path.join(args.outfolder, "uncorrected_cells"),
        filename="dense_umis.tsv",
    )

    # Create sparse matrices for results
    (umi_results_matrix, read_results_matrix) = processing.generate_sparse_matrices(
        final_results=final_results,
        ordered_tags_map=ordered_tags_map,
        top_cells=top_cells,
    )

    # Write umis to file
    io.write_to_files(
        sparse_matrix=umi_results_matrix,
        top_cells=top_cells,
        ordered_tags_map=ordered_tags_map,
        data_type="umi",
        outfolder=args.outfolder,
    )

    # Write reads to file
    io.write_to_files(
        sparse_matrix=read_results_matrix,
        top_cells=top_cells,
        ordered_tags_map=ordered_tags_map,
        data_type="read",
        outfolder=args.outfolder,
    )

    # Write unmapped sequences
    io.write_unmapped(
        merged_no_match=merged_no_match,
        top_unknowns=args.unknowns_top,
        outfolder=args.outfolder,
        filename=args.unmapped_file,
    )

    # Create report and write it to disk
    create_report(
        n_reads=n_reads,
        reads_per_cell=reads_per_cell,
        no_match=merged_no_match,
        version=version,
        start_time=start_time,
        ordered_tags_map=ordered_tags_map,
        umis_corrected=umis_corrected,
        bcs_corrected=bcs_corrected,
        bad_cells=aberrant_cells,
        args=args,
    )

    # Write dense matrix to disk if requested
    if args.dense:
        print("Writing dense format output")
        io.write_dense(
            sparse_matrix=umi_results_matrix,
            index=list(ordered_tags_map.keys()),
            columns=top_cells,
            outfolder=args.outfolder,
            filename="dense_umis.tsv",
        )
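
The multi-core branch above splits the read range with preprocessing.chunk_reads, whose implementation is outside this excerpt. A minimal sketch of such a chunker (an assumption, not the package's actual code) that yields [start, stop] pairs compatible with the indexes=[0, n_reads] argument used in the single-core branch:

def chunk_reads(n_reads, n_chunks):
    # Split [0, n_reads) into n_chunks nearly equal [start, stop] pairs so
    # that each worker maps a contiguous slice of the FASTQ files.
    base, extra = divmod(n_reads, n_chunks)
    indexes = []
    start = 0
    for i in range(n_chunks):
        stop = start + base + (1 if i < extra else 0)
        indexes.append([start, stop])
        start = stop
    return indexes

# chunk_reads(10, 3) -> [[0, 4], [4, 7], [7, 10]]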
Exemplo n.º 27
0
def main():
    #Create logger and stream handler
    logger = logging.getLogger('cite_seq_count')
    logger.setLevel(logging.CRITICAL)
    ch = logging.StreamHandler()
    ch.setLevel(logging.CRITICAL)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    start_time = time.time()
    parser = get_args()
    if not sys.argv[1:]:
        parser.print_help(file=sys.stderr)
        sys.exit(2)

    # Parse arguments.
    args = parser.parse_args()
    if args.whitelist:
        (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv(
            filename=args.whitelist,
            barcode_length=args.cb_last - args.cb_first + 1,
            collapsing_threshold=args.bc_threshold)
    else:
        whitelist = False

    # Load TAGs/ABs.
    ab_map = preprocessing.parse_tags_csv(args.tags)
    ab_map = preprocessing.check_tags(ab_map, args.max_error)
    # Get read lengths. So far, there is no validation for Read2.
    read1_length = preprocessing.get_read_length(args.read1_path)
    read2_length = preprocessing.get_read_length(args.read2_path)
    # Check Read1 length against CELL and UMI barcodes length.
    (barcode_slice, umi_slice,
     barcode_umi_length) = preprocessing.check_barcodes_lengths(
         read1_length, args.cb_first, args.cb_last, args.umi_first,
         args.umi_last)

    if args.first_n:
        n_lines = args.first_n * 4
    else:
        n_lines = preprocessing.get_n_lines(args.read1_path)
    n_reads = int(n_lines / 4)
    n_threads = args.n_threads
    print('Started mapping')
    print('Processing {:,} reads'.format(n_reads))
    #Run with one process
    if n_threads <= 1 or n_reads < 1000001:
        print('CITE-seq-Count is running with one core.')
        (final_results, merged_no_match) = processing.map_reads(
            read1_path=args.read1_path,
            read2_path=args.read2_path,
            tags=ab_map,
            barcode_slice=barcode_slice,
            umi_slice=umi_slice,
            indexes=[0, n_reads],
            whitelist=whitelist,
            debug=args.debug,
            start_trim=args.start_trim,
            maximum_distance=args.max_error,
            sliding_window=args.sliding_window)
        print('Mapping done')
        umis_per_cell = Counter()
        reads_per_cell = Counter()
        for cell_barcode, counts in final_results.items():
            umis_per_cell[cell_barcode] = sum(
                [len(counts[UMI]) for UMI in counts])
            reads_per_cell[cell_barcode] = sum(
                [sum(counts[UMI].values()) for UMI in counts])
    else:
        # Run with multiple processes
        print('CITE-seq-Count is running with {} cores.'.format(n_threads))
        p = Pool(processes=n_threads)
        chunk_indexes = preprocessing.chunk_reads(n_reads, n_threads)
        parallel_results = []

        for indexes in chunk_indexes:
            p.apply_async(processing.map_reads,
                          args=(args.read1_path, args.read2_path, ab_map,
                                barcode_slice, umi_slice, indexes, whitelist,
                                args.debug, args.start_trim, args.max_error,
                                args.sliding_window),
                          callback=parallel_results.append,
                          # error_callback must be a callable; write the exception to stderr
                          error_callback=lambda e: print(e, file=sys.stderr))
        p.close()
        p.join()
        print('Mapping done')
        print('Merging results')

        (final_results, umis_per_cell, reads_per_cell,
         merged_no_match) = processing.merge_results(
             parallel_results=parallel_results)
        del (parallel_results)

    ordered_tags_map = OrderedDict()
    for i, tag in enumerate(ab_map.values()):
        ordered_tags_map[tag] = i
    ordered_tags_map['unmapped'] = i + 1

    # Correct cell barcodes
    if (len(umis_per_cell) <= args.expected_cells):
        print("Number of expected cells, {}, is higher " \
            "than number of cells found {}.\nNot performing" \
            "cell barcode correction" \
            "".format(args.expected_cells, len(umis_per_cell)))
        bcs_corrected = 0
    else:
        print('Correcting cell barcodes')
        if not whitelist:
            (final_results, umis_per_cell,
             bcs_corrected) = processing.correct_cells(
                 final_results=final_results,
                 umis_per_cell=umis_per_cell,
                 expected_cells=args.expected_cells,
                 collapsing_threshold=args.bc_threshold)
        else:
            (final_results, umis_per_cell,
             bcs_corrected) = processing.correct_cells_whitelist(
                 final_results=final_results,
                 umis_per_cell=umis_per_cell,
                 whitelist=whitelist,
                 collapsing_threshold=args.bc_threshold)

    # Correct umi barcodes
    if not whitelist:
        top_cells_tuple = umis_per_cell.most_common(args.expected_cells)
        top_cells = set([pair[0] for pair in top_cells_tuple])

    # Sort cells by number of mapped umis
    else:
        top_cells = whitelist
        # Add potential missing cell barcodes.
        for missing_cell in whitelist:
            if missing_cell in final_results:
                continue
            else:
                final_results[missing_cell] = dict()
                for TAG in ordered_tags_map:
                    final_results[missing_cell][TAG] = Counter()
                top_cells.add(missing_cell)
    #If we want umi correction
    if not args.no_umi_correction:
        (final_results, umis_corrected,
         aberrant_cells) = processing.correct_umis(
             final_results=final_results,
             collapsing_threshold=args.umi_threshold,
             top_cells=top_cells,
             max_umis=20000)
    else:
        umis_corrected = 0
        aberrant_cells = set()
    for cell_barcode in aberrant_cells:
        top_cells.remove(cell_barcode)
    #Create sparse aberrant cells matrix
    (umi_aberrant_matrix,
     read_aberrant_matrix) = processing.generate_sparse_matrices(
         final_results=final_results,
         ordered_tags_map=ordered_tags_map,
         top_cells=aberrant_cells)

    #Write uncorrected cells to dense output
    io.write_dense(sparse_matrix=umi_aberrant_matrix,
                   index=list(ordered_tags_map.keys()),
                   columns=aberrant_cells,
                   outfolder=os.path.join(args.outfolder, 'uncorrected_cells'),
                   filename='dense_umis.tsv')

    (umi_results_matrix,
     read_results_matrix) = processing.generate_sparse_matrices(
         final_results=final_results,
         ordered_tags_map=ordered_tags_map,
         top_cells=top_cells)
    # Write umis to file
    io.write_to_files(sparse_matrix=umi_results_matrix,
                      top_cells=top_cells,
                      ordered_tags_map=ordered_tags_map,
                      data_type='umi',
                      outfolder=args.outfolder)
    # Write reads to file
    io.write_to_files(sparse_matrix=read_results_matrix,
                      top_cells=top_cells,
                      ordered_tags_map=ordered_tags_map,
                      data_type='read',
                      outfolder=args.outfolder)

    top_unmapped = merged_no_match.most_common(args.unknowns_top)

    with open(os.path.join(args.outfolder, args.unmapped_file),
              'w') as unknown_file:
        unknown_file.write('tag,count\n')
        for element in top_unmapped:
            unknown_file.write('{},{}\n'.format(element[0], element[1]))
    create_report(n_reads=n_reads,
                  reads_per_cell=reads_per_cell,
                  no_match=merged_no_match,
                  version=version,
                  start_time=start_time,
                  ordered_tags_map=ordered_tags_map,
                  umis_corrected=umis_corrected,
                  bcs_corrected=bcs_corrected,
                  bad_cells=aberrant_cells,
                  args=args)
    if args.dense:
        print('Writing dense format output')
        io.write_dense(sparse_matrix=umi_results_matrix,
                       index=list(ordered_tags_map.keys()),
                       columns=top_cells,
                       outfolder=args.outfolder,
                       filename='dense_umis.tsv')
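
Both versions of main() hand the per-chunk outputs collected from the workers to processing.merge_results, which is not shown in this excerpt. A sketch of how those outputs could be combined, assuming each worker returns the same (results, no_match) pair as the single-core map_reads call (an illustration only, not the package's actual code):

from collections import Counter, defaultdict

def merge_results(parallel_results):
    # Merge the nested cell -> tag -> Counter(UMI -> reads) dicts produced by
    # each chunk, then recompute the per-cell summary counters.
    final_results = defaultdict(lambda: defaultdict(Counter))
    merged_no_match = Counter()
    for chunk_results, chunk_no_match in parallel_results:
        merged_no_match.update(chunk_no_match)
        for cell_barcode, tags in chunk_results.items():
            for tag, umis in tags.items():
                final_results[cell_barcode][tag] += umis  # Counter + Counter
    umis_per_cell = Counter()
    reads_per_cell = Counter()
    for cell_barcode, tags in final_results.items():
        umis_per_cell[cell_barcode] = sum(len(umis) for umis in tags.values())
        reads_per_cell[cell_barcode] = sum(
            sum(umis.values()) for umis in tags.values())
    return final_results, umis_per_cell, reads_per_cell, merged_no_match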
Exemplo n.º 28
0
def test():
    print('cpuCount() = %d\n' % cpuCount())

    #
    # Create pool
    #

    PROCESSES = 4
    print('Creating pool with %d processes\n' % PROCESSES)
    pool = Pool(PROCESSES)

    #
    # Tests
    #

    TASKS = [(mul, (i, 7)) for i in range(10)] + \
            [(plus, (i, 8)) for i in range(10)]

    results = [pool.apply_async(calculate, t) for t in TASKS]
    imap_it = pool.imap(calculatestar, TASKS)
    imap_unordered_it = pool.imap_unordered(calculatestar, TASKS)

    print('Ordered results using pool.apply_async():')
    for r in results:
        print('\t', r.get())
    print()

    print('Ordered results using pool.imap():')
    for x in imap_it:
        print('\t', x)
    print()

    print('Unordered results using pool.imap_unordered():')
    for x in imap_unordered_it:
        print('\t', x)
    print()

    print('Ordered results using pool.map() --- will block till complete:')
    for x in pool.map(calculatestar, TASKS):
        print('\t', x)
    print()

    #
    # Simple benchmarks
    #

    N = 100000
    print('def pow3(x): return x**3')

    t = time.time()
    A = list(map(pow3, range(N)))
    print('\tmap(pow3, range(%d)):\n\t\t%s seconds' % \
          (N, time.time() - t))

    t = time.time()
    B = pool.map(pow3, range(N))
    print('\tpool.map(pow3, range(%d)):\n\t\t%s seconds' % \
          (N, time.time() - t))

    t = time.time()
    C = list(pool.imap(pow3, range(N), chunksize=N // 8))
    print('\tlist(pool.imap(pow3, range(%d), chunksize=%d)):\n\t\t%s' \
          ' seconds' % (N, N//8, time.time() - t))

    assert A == B == C, (len(A), len(B), len(C))
    print()

    L = [None] * 1000000
    print('def noop(x): pass')
    print('L = [None] * 1000000')

    t = time.time()
    A = list(map(noop, L))
    print('\tmap(noop, L):\n\t\t%s seconds' % \
          (time.time() - t))

    t = time.time()
    B = pool.map(noop, L)
    print('\tpool.map(noop, L):\n\t\t%s seconds' % \
          (time.time() - t))

    t = time.time()
    C = list(pool.imap(noop, L, chunksize=len(L) // 8))
    print('\tlist(pool.imap(noop, L, chunksize=%d)):\n\t\t%s seconds' % \
          (len(L)//8, time.time() - t))

    assert A == B == C, (len(A), len(B), len(C))
    print()

    del A, B, C, L

    #
    # Test error handling
    #

    print('Testing error handling:')

    try:
        print(pool.apply(f, (5, )))
    except ZeroDivisionError:
        print('\tGot ZeroDivisionError as expected from pool.apply()')
    else:
        raise AssertionError('expected ZeroDivisionError')

    try:
        print(pool.map(f, range(10)))
    except ZeroDivisionError:
        print('\tGot ZeroDivisionError as expected from pool.map()')
    else:
        raise AssertionError('expected ZeroDivisionError')

    try:
        print(list(pool.imap(f, range(10))))
    except ZeroDivisionError:
        print('\tGot ZeroDivisionError as expected from list(pool.imap())')
    else:
        raise AssertionError('expected ZeroDivisionError')

    it = pool.imap(f, range(10))
    for i in range(10):
        try:
            x = it.next()
        except ZeroDivisionError:
            if i == 5:
                pass
        except StopIteration:
            break
        else:
            if i == 5:
                raise AssertionError('expected ZeroDivisionError')

    assert i == 9
    print('\tGot ZeroDivisionError as expected from IMapIterator.next()')
    print()

    #
    # Testing timeouts
    #

    print('Testing ApplyResult.get() with timeout:', end='')
    res = pool.apply_async(calculate, TASKS[0])
    while 1:
        sys.stdout.flush()
        try:
            sys.stdout.write('\n\t%s' % res.get(0.02))
            break
        except TimeoutError:
            sys.stdout.write('.')
    print()
    print()

    print('Testing IMapIterator.next() with timeout:', end='')
    it = pool.imap(calculatestar, TASKS)
    while 1:
        sys.stdout.flush()
        try:
            sys.stdout.write('\n\t%s' % it.next(0.02))
        except StopIteration:
            break
        except TimeoutError:
            sys.stdout.write('.')
    print()
    print()

    #
    # Testing callback
    #

    print('Testing callback:')

    A = []
    B = [56, 0, 1, 8, 27, 64, 125, 216, 343, 512, 729]

    r = pool.apply_async(mul, (7, 8), callback=A.append)
    r.wait()

    r = pool.map_async(pow3, range(10), callback=A.extend)
    r.wait()

    if A == B:
        print('\tcallbacks succeeded\n')
    else:
        print('\t*** callbacks failed\n\t\t%s != %s\n' % (A, B))

    #
    # Check there are no outstanding tasks
    #

    assert not pool._cache, 'cache = %r' % pool._cache

    #
    # Check close() methods
    #

    print('Testing close():')

    for worker in pool._pool:
        assert worker.is_alive()

    result = pool.apply_async(time.sleep, [0.5])
    pool.close()
    pool.join()

    assert result.get() is None

    for worker in pool._pool:
        assert not worker.is_alive()

    print('\tclose() succeeded\n')

    #
    # Check terminate() method
    #

    print('Testing terminate():')

    pool = Pool(2)
    ignore = pool.apply(pow3, [2])
    results = [pool.apply_async(time.sleep, [10]) for i in range(10)]
    pool.terminate()
    pool.join()

    for worker in pool._pool:
        assert not worker.is_alive()

    print('\tterminate() succeeded\n')

    #
    # Check garbage collection
    #

    print('Testing garbage collection:')

    pool = Pool(2)
    processes = pool._pool

    ignore = pool.apply(pow3, [2])
    results = [pool.apply_async(time.sleep, [10]) for i in range(10)]

    del results, pool

    time.sleep(0.2)

    for worker in processes:
        assert not worker.is_alive()

    print('\tgarbage collection succeeded\n')
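
test() above assumes a handful of helper functions (calculate, calculatestar, mul, plus, f, pow3, noop) and a cpuCount alias that are defined outside this excerpt. Definitions consistent with its assertions (the ZeroDivisionError at x == 5, the 56 from mul(7, 8) and the cubes in the callback check) would look roughly like:

import random
import time
from multiprocessing import current_process
from multiprocessing import cpu_count as cpuCount  # assumed alias used by test()

def calculate(func, args):
    result = func(*args)
    return '%s says that %s%s = %s' % (
        current_process().name, func.__name__, args, result)

def calculatestar(args):
    return calculate(*args)

def mul(a, b):
    time.sleep(0.5 * random.random())
    return a * b

def plus(a, b):
    time.sleep(0.5 * random.random())
    return a + b

def f(x):
    return 1.0 / (x - 5)  # raises ZeroDivisionError when x == 5

def pow3(x):
    return x ** 3

def noop(x):
    pass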
Exemplo n.º 29
0
class GroupCheckerGui(BaseWidget):

    def __init__(self):
        super(GroupCheckerGui, self).__init__('Group Checker')

        self._group_name = ControlText('Group Name', CONFIG['group_name'])
        self._group_name.enabled = False
        self._allowed_tags = UnicodeControlList('Allowed Tags',
                                               plusFunction=self.__add_tag_action,
                                               minusFunction=self.__remove_tag_action)
        self.allowed_tags = GuiList(CONFIG['white_filters']['SubstringFilter']['substrings'],
                                    self._allowed_tags)

        self._allowed_ids = ControlList('Allowed Ids',
                                        plusFunction=self.__add_id_action,
                                        minusFunction=self.__remove_id_action)
        self.allowed_ids = GuiList(CONFIG['white_filters']['SignerFilter']['ids'], self._allowed_ids)

        self._bad_posts = ControlCheckBoxList('Bad posts')
        self._bad_posts._form.listWidget.itemDoubleClicked.connect(self.__show_link_action)

        self._remove_button = ControlButton('Remove')
        self._remove_button.value = self.__remove_action

        self._show_button = ControlButton('Show bad posts')
        self._show_button.value = self.__show_bad_post_action

        self.pool = Pool(processes=1)
        self.bad_posts = []

        self._formset = [('', '_group_name', ''),
                         ('', '_allowed_tags', '_allowed_ids', ''),
                         '',
                         ('', '_bad_posts', ''),
                         ('', '_remove_button', '_show_button', ''),
                         '']

    def __add_tag_action(self):
        win = PopUpGetText('tag', self.allowed_tags)
        win.show()

    def __remove_tag_action(self):
        self.allowed_tags.remove(self._allowed_tags.mouseSelectedRowIndex)

    def __add_id_action(self):
        win = PopUpGetText('id', self.allowed_ids)
        win.show()

    def __remove_id_action(self):
        self.allowed_ids.remove(self._allowed_ids.mouseSelectedRowIndex)

    def __show_bad_post_action(self):
        def callback(posts):
            self.bad_posts = posts
            self._bad_posts.value = [(GroupBot.get_link_from_post(post, CONFIG['group_name']), True) for post in posts]
            self._show_button.enabled = True

        def run_bot():
            bot = create_bot_from_config()
            return bot.get_bad_posts()

        self._show_button.enabled = False
        self.pool.apply_async(run_bot, callback=callback)

    def __show_link_action(self, link):
        webbrowser.open(link.text())

    def __remove_action(self):
        checked_posts = [self.bad_posts[idx] for idx in self._bad_posts.checkedIndexes]
        bot = create_bot_from_config()
        bot.remove_posts(checked_posts)
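
__show_bad_post_action above illustrates the pattern this GUI uses for long-running work: the blocking call is submitted to a one-worker Pool with apply_async, and a callback updates the widgets when the result arrives, so the interface stays responsive. A framework-independent sketch of the same pattern (hypothetical helper names, not the pyforms API):

from multiprocessing import Pool

def fetch_bad_posts():
    # Stand-in for the slow, blocking work (bot.get_bad_posts() above).
    return ['post-1', 'post-2']

def on_done(posts):
    # Invoked in the parent process once the worker finishes; a GUI would
    # update its widgets and re-enable buttons here.
    print('received %d bad posts' % len(posts))

if __name__ == '__main__':
    pool = Pool(processes=1)
    pool.apply_async(fetch_bad_posts, callback=on_done)  # returns immediately
    pool.close()
    pool.join()  # a real GUI keeps running its event loop instead of joining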