Example #1
    def test_abort(self):
        # Trigger a ridiculous number of tasks, and abort the remaining ones.
        with threading_utils.ThreadPool(2, 2, 0) as pool:
            # Allow 10 tasks to run initially.
            sem = threading.Semaphore(10)

            def grab_and_return(x):
                sem.acquire()
                return x

            for i in range(100):
                pool.add_task(0, grab_and_return, i)

            # Fetching an 11th result here would hang.
            results = [pool.get_one_result() for _ in xrange(10)]
            # At that point, there are 10 completed tasks, 2 tasks hanging and
            # 88 pending.
            self.assertEqual(88, pool.abort())
            # Calling .join() before these 2 .release() would hang.
            sem.release()
            sem.release()
            results.extend(pool.join())
        # The results *may* be out of order. Even if the calls are processed
        # strictly in FIFO mode, a thread may preempt another one when returning the
        # values.
        self.assertEqual(range(12), sorted(results))
Example #2
 def test_send_exception_raises_exception(self):
   class CustomError(Exception):
     pass
   with threading_utils.ThreadPool(1, 1, 0) as tp:
     channel = threading_utils.TaskChannel()
     tp.add_task(0, lambda: channel.send_exception(CustomError()))
     with self.assertRaises(CustomError):
       channel.pull()
Example #3
 def test_timeout_exception_from_task(self):
   with threading_utils.ThreadPool(1, 1, 0) as tp:
     channel = threading_utils.TaskChannel()
     def task_func():
       raise threading_utils.TaskChannel.Timeout()
     tp.add_task(0, channel.wrap_task(task_func))
     # 'Timeout' raised by task gets transformed into 'RuntimeError'.
     with self.assertRaises(RuntimeError):
       channel.next()
Example #4
 def test_wrap_task_raises_exception(self):
   class CustomError(Exception):
     pass
   with threading_utils.ThreadPool(1, 1, 0) as tp:
     channel = threading_utils.TaskChannel()
     def task_func():
       raise CustomError()
     tp.add_task(0, channel.wrap_task(task_func))
     with self.assertRaises(CustomError):
       channel.pull()
Example #5
    def test_send_exception_raises_exception(self):
        class CustomError(Exception):
            pass

        with threading_utils.ThreadPool(1, 1, 0) as tp:
            channel = threading_utils.TaskChannel()
            exc_info = (CustomError, CustomError(), None)
            tp.add_task(0, lambda: channel.send_exception(exc_info))
            with self.assertRaises(CustomError):
                next(channel)
Example #6
 def test_next_timeout(self):
   with threading_utils.ThreadPool(1, 1, 0) as tp:
     channel = threading_utils.TaskChannel()
     def task_func():
       # This test ultimately relies on the condition variable primitive
       # provided by pthreads. There's no easy way to mock time for it.
       # Increase this duration if the test is flaky.
       time.sleep(0.2)
       return 123
     tp.add_task(0, channel.wrap_task(task_func))
     with self.assertRaises(threading_utils.TaskChannel.Timeout):
       channel.next(timeout=0.001)
     self.assertEqual(123, channel.next())
Example #7
 def test_wrap_task_exception_captures_stack_trace(self):
   class CustomError(Exception):
     pass
   with threading_utils.ThreadPool(1, 1, 0) as tp:
     channel = threading_utils.TaskChannel()
     def task_func():
       def function_with_some_unusual_name():
         raise CustomError()
       function_with_some_unusual_name()
     tp.add_task(0, channel.wrap_task(task_func))
     exc_traceback = ''
     try:
       channel.next()
     except CustomError:
       exc_traceback = traceback.format_exc()
     self.assertIn('function_with_some_unusual_name', exc_traceback)
Example #8
def _fetch_daily_internal(delta, swarming, process, endpoint, start, end,
                          state, tags, parallel):
    """Executes 'process' by parallelizing it once per day."""
    out = {}
    with threading_utils.ThreadPool(1, parallel, 0) as pool:
        while start < end:
            cmd = _get_cmd(swarming, endpoint, _get_epoch(start),
                           _get_epoch(start + delta), state, tags)
            pool.add_task(0, _run_json, start.strftime('%Y-%m-%d'), process,
                          cmd)
            start += delta
        for k, v in pool.iter_results():
            sys.stdout.write('.')
            sys.stdout.flush()
            out[k] = v
    print('')
    return out
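
The docstring above describes a simple one-task-per-day fan-out. The sketch below shows how such a helper might be driven; the host, endpoint, date range and the 'process' callback are hypothetical, and the exact callback contract depends on _run_json, which is not shown here.

import datetime

def process(day_data):
    # Hypothetical 'process' callback: pass the day's JSON through unchanged.
    return day_data

start = datetime.datetime(2024, 1, 1)
end = datetime.datetime(2024, 1, 8)
daily = _fetch_daily_internal(
    datetime.timedelta(days=1), 'chromium-swarm.appspot.com', process,
    'tasks/count', start, end, 'COMPLETED', [], 4)
# 'daily' presumably maps 'YYYY-MM-DD' strings to whatever _run_json returned
# for that day.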
Example #9
    def test_priority(self):
        # Verifies that a task with a lower priority value is run first.
        with threading_utils.ThreadPool(1, 1, 0) as pool:
            lock = threading.Lock()

            def wait_and_return(x):
                with lock:
                    return x

            def return_x(x):
                return x

            with lock:
                pool.add_task(0, wait_and_return, 'a')
                pool.add_task(2, return_x, 'b')
                pool.add_task(1, return_x, 'c')

            actual = pool.join()
        self.assertEqual(['a', 'c', 'b'], actual)
Example #10
def yield_results(swarm_base_url, test_keys, timeout, max_threads):
  """Yields swarm test results from the swarm server as (index, result).

  Duplicate shards are ignored; the first one to complete is returned.

  max_threads is optional and is used to limit the number of parallel fetches
  done. Since the number of test_keys is generally <= 10, it's normally not
  worth limiting the number of threads. Mostly used for testing purposes.
  """
  shards_remaining = range(len(test_keys))
  number_threads = (
      min(max_threads, len(test_keys)) if max_threads else len(test_keys))
  should_stop = threading_utils.Bit()
  results_remaining = len(test_keys)
  with threading_utils.ThreadPool(number_threads, number_threads, 0) as pool:
    try:
      for test_key in test_keys:
        pool.add_task(
            0, retrieve_results, swarm_base_url, test_key, timeout, should_stop)
      while shards_remaining and results_remaining:
        result = pool.get_one_result()
        results_remaining -= 1
        if not result:
          # Failed to retrieve one key.
          logging.error('Failed to retrieve the results for a swarm key')
          continue
        shard_index = result['config_instance_index']
        if shard_index in shards_remaining:
          shards_remaining.remove(shard_index)
          yield shard_index, result
        else:
          logging.warning('Ignoring duplicate shard index %d', shard_index)
          # Pop the last entry; there's no such shard.
          shards_remaining.pop()
    finally:
      # Done, kill the remaining threads.
      should_stop.set()
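
A minimal caller sketch for the generator above, using a hypothetical server URL and test keys; it relies only on the (index, result) contract stated in the docstring, with max_threads left unset.

test_keys = ['key-0', 'key-1', 'key-2']  # hypothetical keys
outputs = {}
for index, result in yield_results(
    'https://chromium-swarm.appspot.com', test_keys, 60., None):
  outputs[index] = result
# Duplicate shards were filtered by the generator, so there is at most one
# result per shard index.
assert len(outputs) <= len(test_keys)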
Example #11
 def test_wrap_task_passes_exception_value(self):
     with threading_utils.ThreadPool(1, 1, 0) as tp:
         channel = threading_utils.TaskChannel()
         tp.add_task(0, channel.wrap_task(lambda: Exception()))
         self.assertTrue(isinstance(channel.pull(), Exception))
Example #12
 def test_wrap_task_passes_simple_value(self):
     with threading_utils.ThreadPool(1, 1, 0) as tp:
         channel = threading_utils.TaskChannel()
         tp.add_task(0, channel.wrap_task(lambda: 0))
         self.assertEqual(0, channel.pull())
Example #13
 def test_double_close(self):
     pool = threading_utils.ThreadPool(1, 1, 0)
     pool.close()
     with self.assertRaises(threading_utils.ThreadPoolClosed):
         pool.close()
Example #14
 def test_adding_tasks_after_close(self):
     pool = threading_utils.ThreadPool(1, 1, 0)
     pool.add_task(0, lambda: None)
     pool.close()
     with self.assertRaises(threading_utils.ThreadPoolClosed):
         pool.add_task(0, lambda: None)
Example #15
 def setUp(self):
     super(ThreadPoolTest, self).setUp()
     self.thread_pool = threading_utils.ThreadPool(self.MIN_THREADS,
                                                   self.MAX_THREADS, 0)
Example #16
 def test_passes_simple_value(self):
     with threading_utils.ThreadPool(1, 1, 0) as tp:
         channel = threading_utils.TaskChannel()
         tp.add_task(0, lambda: channel.send_result(0))
         self.assertEqual(0, channel.next())
Example #17
 def test_passes_exception_value(self):
     with threading_utils.ThreadPool(1, 1, 0) as tp:
         channel = threading_utils.TaskChannel()
         tp.add_task(0, lambda: channel.send_result(Exception()))
         self.assertTrue(isinstance(channel.next(), Exception))
Example #18
    def test_trace_multiple(self):
        # Starts parallel threads and traces child processes simultaneously.
        # Some are started from the 'tests' directory, others from this script's
        # directory. One trace fails. Verify everything still goes on.
        parallel = 8

        def trace(tracer, cmd, cwd, tracename):
            resultcode, output = tracer.trace(cmd, cwd, tracename, True)
            return (tracename, resultcode, output)

        with threading_utils.ThreadPool(parallel, parallel, 0) as pool:
            api = trace_inputs.get_api()
            with api.get_tracer(self.log) as tracer:
                pool.add_task(0, trace, tracer, self.get_child_command(False),
                              ROOT_DIR, 'trace1')
                pool.add_task(0, trace, tracer, self.get_child_command(True),
                              self.cwd, 'trace2')
                pool.add_task(0, trace, tracer, self.get_child_command(False),
                              ROOT_DIR, 'trace3')
                pool.add_task(0, trace, tracer, self.get_child_command(True),
                              self.cwd, 'trace4')
                # Have this one fail since it's started from the wrong directory.
                pool.add_task(0, trace, tracer, self.get_child_command(False),
                              self.cwd, 'trace5')
                pool.add_task(0, trace, tracer, self.get_child_command(True),
                              self.cwd, 'trace6')
                pool.add_task(0, trace, tracer, self.get_child_command(False),
                              ROOT_DIR, 'trace7')
                pool.add_task(0, trace, tracer, self.get_child_command(True),
                              self.cwd, 'trace8')
                trace_results = pool.join()

        def blacklist(f):
            return f.endswith(('.pyc', 'do_not_care.txt', '.git', '.svn'))

        actual_results = api.parse_log(self.log, blacklist, None)
        self.assertEqual(8, len(trace_results))
        self.assertEqual(8, len(actual_results))

        # Convert to dict keyed on the trace name, simpler to verify.
        trace_results = dict((i[0], i[1:]) for i in trace_results)
        actual_results = dict((x.pop('trace'), x) for x in actual_results)
        self.assertEqual(sorted(trace_results), sorted(actual_results))

        # It'd be nice to start different kinds of processes.
        expected_results = [
            self._gen_dict_full(),
            self._gen_dict_full_gyp(),
            self._gen_dict_full(),
            self._gen_dict_full_gyp(),
            self._gen_dict_wrong_path(),
            self._gen_dict_full_gyp(),
            self._gen_dict_full(),
            self._gen_dict_full_gyp(),
        ]
        self.assertEqual(len(expected_results), len(trace_results))

        # See the comment above about the trace that fails because it's started from
        # the wrong directory.
        busted = 4
        for index, key in enumerate(sorted(actual_results)):
            self.assertEqual('trace%d' % (index + 1), key)
            self.assertEqual(2, len(trace_results[key]))
            # returncode
            self.assertEqual(0 if index != busted else 2,
                             trace_results[key][0])
            # output
            self.assertEqual(actual_results[key]['output'],
                             trace_results[key][1])

            self.assertEqual(['output', 'results'],
                             sorted(actual_results[key]))
            results = actual_results[key]['results']
            results = results.strip_root(unicode(ROOT_DIR))
            actual = results.flatten()
            self.assertTrue(actual['root'].pop('pid'))
            if index != busted:
                self.assertTrue(actual['root']['children'][0].pop('pid'))
            self.assertEqual(expected_results[index], actual)
Example #19
def yield_results(swarm_base_url, task_keys, timeout, max_threads,
                  print_status_updates, output_collector):
    """Yields swarming task results from the swarming server as (index, result).

  Duplicate shards are ignored. Shards are yielded in order of completion.
  Timed out shards are NOT yielded at all. Caller can compare number of yielded
  shards with len(task_keys) to verify all shards completed.

  max_threads is optional and is used to limit the number of parallel fetches
  done. Since the number of task_keys is generally <= 10, it's normally not
  worth limiting the number of threads. Mostly used for testing purposes.

  output_collector is an optional instance of TaskOutputCollector that will be
  used to fetch files produced by a task from isolate server to the local disk.

  Yields:
    (index, result). In particular, 'result' is defined as the
    GetRunnerResults() function in services/swarming/server/test_runner.py.
  """
    number_threads = (min(max_threads, len(task_keys))
                      if max_threads else len(task_keys))
    should_stop = threading.Event()
    results_channel = threading_utils.TaskChannel()

    with threading_utils.ThreadPool(number_threads, number_threads, 0) as pool:
        try:
            # Adds a task to the thread pool to call 'retrieve_results' and return
            # the results together with shard_index that produced them (as a tuple).
            def enqueue_retrieve_results(shard_index, task_key):
                task_fn = lambda *args: (shard_index, retrieve_results(*args))
                pool.add_task(0, results_channel.wrap_task(task_fn),
                              swarm_base_url, shard_index, task_key, timeout,
                              should_stop, output_collector)

            # Enqueue 'retrieve_results' calls for each shard key to run in parallel.
            for shard_index, task_key in enumerate(task_keys):
                enqueue_retrieve_results(shard_index, task_key)

            # Wait for all of them to finish.
            shards_remaining = range(len(task_keys))
            active_task_count = len(task_keys)
            while active_task_count:
                shard_index, result = None, None
                try:
                    shard_index, result = results_channel.pull(
                        timeout=STATUS_UPDATE_INTERVAL)
                except threading_utils.TaskChannel.Timeout:
                    if print_status_updates:
                        print(
                            'Waiting for results from the following shards: %s'
                            % ', '.join(map(str, shards_remaining)))
                        sys.stdout.flush()
                    continue
                except Exception:
                    logging.exception(
                        'Unexpected exception in retrieve_results')

                # A call to 'retrieve_results' finished (successfully or not).
                active_task_count -= 1
                if not result:
                    logging.error(
                        'Failed to retrieve the results for a swarming key')
                    continue

                # Yield back results to the caller.
                assert shard_index in shards_remaining
                shards_remaining.remove(shard_index)
                yield shard_index, result

        finally:
            # Done or aborted with Ctrl+C, kill the remaining threads.
            should_stop.set()
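
A caller sketch for this later variant, again with hypothetical task keys and no output collector; per the docstring, a timed-out shard never shows up among the yielded indexes, so the caller detects it by comparing counts.

task_keys = ['deadbeef01', 'deadbeef02']  # hypothetical task keys
seen = set()
for index, result in yield_results(
        'https://chromium-swarm.appspot.com', task_keys, 3600., None,
        print_status_updates=True, output_collector=None):
    seen.add(index)
missing = sorted(set(range(len(task_keys))) - seen)
if missing:
    logging.error('No results for shards: %s', missing)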