Code example #1
File: driver.py Project: happyAnger6/OpenSpider
import collections

from tornado.locks import Event  # assumed: OpenSpider is Tornado-based


class QueueDriver:
    def __init__(self, **settings):
        self.settings = settings
        self._finished = Event()
        self._getters = collections.deque([])  # Futures.
        self._putters = collections.deque([])
        self.initialize(**settings)

    def initialize(self, **settings):
        pass

    def over(self):
        self._finished.set()

    def save(self):
        raise NotImplementedError()

    def get(self):
        raise NotImplementedError()

    def put(self):
        raise NotImplementedError()

    def join(self, timeout=None):
        return self._finished.wait(timeout)
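A rough usage sketch (MemoryDriver and main are illustrative names, not part of the project; assumes the tornado Event imported above, so join() returns a Future):

from tornado.ioloop import IOLoop


class MemoryDriver(QueueDriver):
    def initialize(self, **settings):
        self.items = collections.deque()

    def put(self, item):
        self.items.append(item)

    def get(self):
        return self.items.popleft()


async def main():
    driver = MemoryDriver()
    driver.put("job-1")
    assert driver.get() == "job-1"
    driver.over()        # sets the internal finished event
    await driver.join()  # resolves immediately because over() already ran


IOLoop.current().run_sync(main)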
Code example #2
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
    def setUp(self):
        self.cleanup_event = Event()
        test = self

        # Dummy Resolver subclass that never finishes.
        class BadResolver(Resolver):
            @gen.coroutine
            def resolve(self, *args, **kwargs):
                yield test.cleanup_event.wait()
                # Return something valid so the test doesn't raise during cleanup.
                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]

        super(ResolveTimeoutTestCase, self).setUp()
        self.http_client = SimpleAsyncHTTPClient(resolver=BadResolver())

    def get_app(self):
        return Application([url("/hello", HelloWorldHandler)])

    def test_resolve_timeout(self):
        with self.assertRaises(HTTPTimeoutError):
            self.fetch("/hello", connect_timeout=0.1, raise_error=True)

        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()
        self.io_loop.run_sync(lambda: gen.sleep(0))
Code example #3
    def test_connect_timeout(self):
        timeout = 0.1

        cleanup_event = Event()
        test = self

        class TimeoutResolver(Resolver):
            async def resolve(self, *args, **kwargs):
                await cleanup_event.wait()
                # Return something valid so the test doesn't raise during shutdown.
                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]

        with closing(self.create_client(resolver=TimeoutResolver())) as client:
            with self.assertRaises(HTTPTimeoutError):
                yield client.fetch(
                    self.get_url("/hello"),
                    connect_timeout=timeout,
                    request_timeout=3600,
                    raise_error=True,
                )

        # Let the hanging coroutine clean up after itself. We need to
        # wait more than a single IOLoop iteration for the SSL case,
        # which logs errors on unexpected EOF.
        cleanup_event.set()
        yield gen.sleep(0.2)
Code example #4
File: handler.py Project: nzinov/phystech-seabattle
class GameHandler(WebSocketHandler):
    def __init__(self, *args, **kwargs):
        super(GameHandler, self).__init__(*args, **kwargs)
        self.game = None
        self.player = None
        self.answered = Event()

    def open(self):
        pass

    def auth(self, data):
        self.game = GameLoader.load(data["game"])
        self.player = data["player"]
        self.game.handlers[self.player] = self
        self.game.introduce(self.player)

    def on_message(self, message):
        data = Dumper.load(message)
        if data["action"] == "auth":
            self.auth(data)
        if self.game is None:
            # Not authenticated: drop the connection and stop processing.
            self.close()
            return
        if data["action"] == "move":
            self.game.take_action(self.player, data)
        if data["action"] == "answer":
            self.answered.set()

    def on_close(self):
        print("WebSocket closed")

    def check_origin(self, origin):
        return True
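The answered event implies that game code elsewhere blocks until the client replies; a hypothetical coroutine doing so might look like this (ask_player and the question payload are illustrative):

from tornado import gen


@gen.coroutine
def ask_player(game, player, question):
    handler = game.handlers[player]
    handler.answered.clear()         # arm the event before asking
    handler.write_message(question)  # standard WebSocketHandler API
    yield handler.answered.wait()    # resolves on the next "answer" message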
Code example #5
File: test_actor.py Project: tomMoral/distributed
    class Waiter(object):
        def __init__(self):
            self.event = Event()

        @gen.coroutine
        def set(self):
            self.event.set()

        @gen.coroutine
        def wait(self):
            yield self.event.wait()
Code example #6
from itertools import count

from tornado import gen
from tornado.locks import Event  # assumed: matches the coroutine usage below


class ImageMutex:

    def __init__(self):
        self._mutex = Event()
        self._blocked = count()
        self._building_log = []
        self._exception = None

    @gen.coroutine
    def block(self):
        value = self._blocked.__next__()  # single bytecode operation
        if value:
            yield self._mutex.wait()
        return value

    def __enter__(self):
        if self._exception is not None:
            raise self._exception
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._building_log = []
        if isinstance(exc_value, Exception):
            self._exception = exc_value
        self._mutex.set()

    def timeout_happened(self):
        self._exception = Exception('This image is too heavy to build')
        self._building_log = []

    def add_to_log(self, message, level=1):
        if not self._exception:
            self._building_log.append({
                'text': message,
                'level': level
            })

    @property
    def building_log(self):
        return self._building_log

    @property
    def last_exception(self):
        return self._exception
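A sketch of the intended coordination (build_once and build_image are illustrative): the first caller sees block() return 0 and builds inside the with block, whose __exit__ sets the event and releases every later caller.

from tornado import gen


@gen.coroutine
def build_once(mutex, build_image):
    already_built = yield mutex.block()  # 0 only for the first caller
    if not already_built:
        with mutex:  # re-raises any stored exception; sets the event on exit
            yield build_image()
    elif mutex.last_exception is not None:
        raise mutex.last_exception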
Code example #7
File: pubnub_tornado.py Project: pubnub/python
    def _perform_heartbeat_loop(self):
        if self._heartbeat_call is not None:
            # TODO: cancel call
            pass

        cancellation_event = Event()
        state_payload = self._subscription_state.state_payload()
        presence_channels = self._subscription_state.prepare_channel_list(False)
        presence_groups = self._subscription_state.prepare_channel_group_list(False)

        if len(presence_channels) == 0 and len(presence_groups) == 0:
            return

        try:
            envelope = yield self._pubnub.heartbeat() \
                .channels(presence_channels) \
                .channel_groups(presence_groups) \
                .state(state_payload) \
                .cancellation_event(cancellation_event) \
                .future()

            heartbeat_verbosity = self._pubnub.config.heartbeat_notification_options
            if envelope.status.is_error:
                if heartbeat_verbosity == PNHeartbeatNotificationOptions.ALL or \
                        heartbeat_verbosity == PNHeartbeatNotificationOptions.FAILURES:
                    self._listener_manager.announce_status(envelope.status)
            else:
                if heartbeat_verbosity == PNHeartbeatNotificationOptions.ALL:
                    self._listener_manager.announce_status(envelope.status)

        except PubNubTornadoException:
            pass
            # TODO: check correctness
            # if e.status is not None and e.status.category == PNStatusCategory.PNTimeoutCategory:
            #     self._start_subscribe_loop()
            # else:
            #     self._listener_manager.announce_status(e.status)
        except Exception as e:
            print(e)
        finally:
            cancellation_event.set()
Code example #8
    def get_response(self, data, method, show_graphiql=False):
        query, variables, operation_name, id = self.get_graphql_params(
            self.request, data)

        execution_result = yield self.execute_graphql_request(
            method, query, variables, operation_name, show_graphiql)

        status_code = 200
        if execution_result:
            response = {}

            if getattr(execution_result, 'is_pending', False):
                event = Event()
                on_resolve = lambda *_: event.set()  # noqa
                execution_result.then(on_resolve).catch(on_resolve)
                yield event.wait()

            if hasattr(execution_result, 'get'):
                execution_result = execution_result.get()

            if execution_result.errors:
                response['errors'] = [
                    self.format_error(e) for e in execution_result.errors
                ]

            if execution_result.invalid:
                status_code = 400
            else:
                response['data'] = execution_result.data

            if self.batch:
                response['id'] = id
                response['status'] = status_code

            result = self.json_encode(response,
                                      pretty=self.pretty or show_graphiql)
        else:
            result = None

        raise Return((result, status_code))
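The Event-based bridge from a then/catch promise into coroutine code generalizes beyond GraphQL; a minimal sketch (assumes the promise object exposes then, catch, and get):

from tornado import gen
from tornado.locks import Event


@gen.coroutine
def wait_for_promise(promise):
    event = Event()
    on_resolve = lambda *_: event.set()  # noqa
    promise.then(on_resolve).catch(on_resolve)
    yield event.wait()               # parked until the promise settles
    raise gen.Return(promise.get())  # get() returns the value or re-raises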
Code example #9
class RDKafkaDrain(object):
    """Implementation of IDrain that produces to a Kafka topic using librdkafka
    asynchronously. Backpressure is implemented with a tornado.queues.Queue.
    Expects an instance of confluent_kafka.Producer as self.emitter.
    """

    def __init__(self, logger, loop, producer, topic, **kwargs):
        self.emitter = producer
        self.logger = logger
        self.loop = loop
        self.loop.spawn_callback(self._poll)
        self._completed = Queue()
        self._ignored_errors = set(kwargs.get('ignored_errors', []))
        # See: https://github.com/confluentinc/confluent-kafka-python/issues/147
        self._ignored_errors.update(TRANSIENT_ERRORS)
        self.metric_prefix = kwargs.get('metric_prefix', 'emitter')
        self.output_error = Event()
        self.sender_tag = 'sender:%s.%s' % (self.__class__.__module__,
                                            self.__class__.__name__)
        self.topic = topic
        self.state = RUNNING

    @gen.coroutine
    def close(self, timeout=INITIAL_TIMEOUT):
        try:
            self.state = CLOSING
            begin = datetime.utcnow()
            num_messages = len(self.emitter)
            elapsed = datetime.utcnow() - begin
            while num_messages > 0 and elapsed <= MAX_TIMEOUT:
                self.logger.info("Flushing send queue in %s/%s: %d",
                                 elapsed, MAX_TIMEOUT, num_messages)
                self.emitter.poll(0)
                num_messages = len(self.emitter)
                elapsed = datetime.utcnow() - begin
                timeout = min(timeout*2, MAX_TIMEOUT)
                yield gen.sleep(timeout.total_seconds())
            if num_messages > 0:
                self.logger.error('Unable to flush messages; aborting')
        finally:
            self.state = CLOSED

    def emit_nowait(self, msg):
        self.logger.debug("Drain emitting")
        try:
            if self.topic == 'dynamic':
                assert isinstance(msg, KafkaMessage)
                self.emitter.produce(
                    msg.topic,
                    msg.value,
                    msg.key,
                    # This callback is executed in the librdkafka thread
                    callback=self._trampoline,
                )
            else:
                assert isinstance(msg, basestring)
                self.emitter.produce(
                    self.topic, msg,
                    # This callback is executed in the librdkafka thread
                    callback=self._trampoline,
                )
        except BufferError:
            raise QueueFull()

    @gen.coroutine
    def emit(self, msg, retry_timeout=INITIAL_TIMEOUT):
        while True:
            try:
                self.emit_nowait(msg)
            except QueueFull:
                yield gen.sleep(retry_timeout.total_seconds())
                retry_timeout = min(retry_timeout*2, MAX_TIMEOUT)
            else:
                # Sent successfully; stop retrying.
                return

    @gen.coroutine
    def _poll(self, retry_timeout=INITIAL_TIMEOUT):
        """Infinite coroutine for draining the delivery report queue,
        with exponential backoff.
        """
        try:
            num_processed = self.emitter.poll(0)
            if num_processed > 0:
                self.logger.debug("Drain received ack for messages: %d",
                                  num_processed)
                retry_timeout = INITIAL_TIMEOUT
            else:
                self.logger.debug("Drain delivery report queue empty")
                # Retry with exponential backoff
                yield gen.sleep(retry_timeout.total_seconds())
                retry_timeout = min(retry_timeout*2, MAX_TIMEOUT)
        finally:
            self.loop.spawn_callback(self._poll, retry_timeout)

    @gen.coroutine
    def _on_track(self, err, kafka_msg):
        self.logger.debug('Received delivery notification: "%s", "%s"',
                          err, kafka_msg)
        if err:
            if err.code() in self._ignored_errors:
                self.logger.warning('Ignoring error: %s', err)
            else:
                self.logger.error('Error encountered, giving up: %s', err)
                self.output_error.set()

    def _trampoline(self, err, kafka_msg):
        # This is necessary, so that we trampoline from the librdkafka thread
        # back to the main Tornado thread:
        # add_callback() may be used to transfer control from other threads to
        # the IOLoop's thread.
        # It is safe to call this method from any thread at any time, except
        # from a signal handler. Note that this is the only method in IOLoop
        # that makes this thread-safety guarantee; all other interaction with
        # the IOLoop must be done from that IOLoop's thread.
        self.loop.add_callback(
            self._on_track, err, kafka_msg
        )
Code example #10
class TornadoReconnectionManager(ReconnectionManager):
    def __init__(self, pubnub):
        self._cancelled_event = Event()
        super(TornadoReconnectionManager, self).__init__(pubnub)

    @gen.coroutine
    def _register_heartbeat_timer(self):
        self._cancelled_event.clear()

        while not self._cancelled_event.is_set():
            if self._pubnub.config.reconnect_policy == PNReconnectionPolicy.EXPONENTIAL:
                self._timer_interval = int(math.pow(2, self._connection_errors) - 1)
                if self._timer_interval > self.MAXEXPONENTIALBACKOFF:
                    self._timer_interval = self.MINEXPONENTIALBACKOFF
                    self._connection_errors = 1
                    logger.debug("timerInterval > MAXEXPONENTIALBACKOFF at: %s" % utils.datetime_now())
                elif self._timer_interval < 1:
                    self._timer_interval = self.MINEXPONENTIALBACKOFF
                logger.debug("timerInterval = %d at: %s" % (self._timer_interval, utils.datetime_now()))
            else:
                self._timer_interval = self.INTERVAL

            # >>> Wait given interval or cancel
            sleeper = tornado.gen.sleep(self._timer_interval)
            canceller = self._cancelled_event.wait()

            wi = tornado.gen.WaitIterator(canceller, sleeper)

            while not wi.done():
                try:
                    future = wi.next()
                    yield future
                except Exception as e:
                    # TODO: verify the error will not be eaten
                    logger.error(e)
                    raise
                else:
                    if wi.current_future == sleeper:
                        break
                    elif wi.current_future == canceller:
                        return
                    else:
                        raise Exception("unknown future raised")

            logger.debug("reconnect loop at: %s" % utils.datetime_now())

            # >>> Attempt to request /time/0 endpoint
            try:
                yield self._pubnub.time().result()
                self._connection_errors = 1
                self._callback.on_reconnect()
                logger.debug("reconnection manager stop due success time endpoint call: %s" % utils.datetime_now())
                break
            except Exception:
                if self._pubnub.config.reconnect_policy == PNReconnectionPolicy.EXPONENTIAL:
                    logger.debug("reconnect interval increment at: %s" % utils.datetime_now())
                    self._connection_errors += 1

    def start_polling(self):
        if self._pubnub.config.reconnect_policy == PNReconnectionPolicy.NONE:
            logger.warning("reconnection policy is disabled, please handle reconnection manually.")
            return

        self._pubnub.ioloop.spawn_callback(self._register_heartbeat_timer)

    def stop_polling(self):
        if self._cancelled_event is not None and not self._cancelled_event.is_set():
            self._cancelled_event.set()
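The sleep-versus-cancel race above is a reusable WaitIterator pattern; reduced to a standalone helper (sleep_unless_cancelled is an illustrative name):

from tornado import gen


@gen.coroutine
def sleep_unless_cancelled(delay, cancelled_event):
    # True if the full delay elapsed, False if the event fired first.
    sleeper = gen.sleep(delay)
    canceller = cancelled_event.wait()
    wi = gen.WaitIterator(canceller, sleeper)
    while not wi.done():
        yield wi.next()
        if wi.current_future is sleeper:
            raise gen.Return(True)
        if wi.current_future is canceller:
            raise gen.Return(False)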
Code example #11
File: index_manager.py Project: AppScale/appscale
class ProjectIndexManager(object):
  """ Keeps track of composite index definitions for a project. """

  def __init__(self, project_id, zk_client, index_manager, datastore_access):
    """ Creates a new ProjectIndexManager.

    Args:
      project_id: A string specifying a project ID.
      zk_client: A KazooClient.
      index_manager: An IndexManager used for checking lock status.
      datastore_access: A DatastoreDistributed object.
    """
    self.project_id = project_id
    self.indexes_node = '/appscale/projects/{}/indexes'.format(self.project_id)
    self.active = True
    self.update_event = AsyncEvent()

    self._creation_times = {}
    self._index_manager = index_manager
    self._zk_client = zk_client
    self._ds_access = datastore_access

    self._zk_client.DataWatch(self.indexes_node, self._update_indexes_watch)

    # Since this manager can be used synchronously, ensure that the indexes
    # are populated for this IOLoop iteration.
    try:
      encoded_indexes = self._zk_client.get(self.indexes_node)[0]
    except NoNodeError:
      encoded_indexes = '[]'

    self.indexes = [DatastoreIndex.from_dict(self.project_id, index)
                    for index in json.loads(encoded_indexes)]

  @property
  def indexes_pb(self):
    if self._zk_client.state != KazooState.CONNECTED:
      raise IndexInaccessible('ZooKeeper connection is not active')

    return [index.to_pb() for index in self.indexes]

  @gen.coroutine
  def apply_definitions(self):
    """ Populate composite indexes that are not marked as ready yet. """
    try:
      yield self.update_event.wait()
      self.update_event.clear()
      if not self._index_manager.admin_lock.is_acquired or not self.active:
        return

      logger.info(
        'Applying composite index definitions for {}'.format(self.project_id))

      for index in self.indexes:
        if index.ready:
          continue

        # Wait until all clients have either timed out or received the new index
        # definition. This prevents entities from being added without entries
        # while the index is being rebuilt.
        creation_time = self._creation_times.get(index.id, time.time())
        consensus = creation_time + (self._zk_client._session_timeout / 1000.0)
        yield gen.sleep(max(consensus - time.time(), 0))

        yield self._ds_access.update_composite_index(
          self.project_id, index.to_pb())
        logger.info('Index {} is now ready'.format(index.id))
        self._mark_index_ready(index.id)

      logger.info(
        'All composite indexes for {} are ready'.format(self.project_id))
    finally:
      IOLoop.current().spawn_callback(self.apply_definitions)

  def delete_index_definition(self, index_id):
    """ Remove a definition from a project's list of configured indexes.

    Args:
      index_id: An integer specifying an index ID.
    """
    try:
      encoded_indexes, znode_stat = self._zk_client.get(self.indexes_node)
    except NoNodeError:
      # If there are no index definitions, there is nothing to do.
      return

    node_version = znode_stat.version
    indexes = [DatastoreIndex.from_dict(self.project_id, index)
               for index in json.loads(encoded_indexes)]

    encoded_indexes = json.dumps([index.to_dict() for index in indexes
                                  if index.id != index_id])
    self._zk_client.set(self.indexes_node, encoded_indexes,
                        version=node_version)

  def _mark_index_ready(self, index_id):
    """ Updates the index metadata to reflect the new state of the index.

    Args:
      index_id: An integer specifying an index ID.
    """
    try:
      encoded_indexes, znode_stat = self._zk_client.get(self.indexes_node)
      node_version = znode_stat.version
    except NoNodeError:
      # If for some reason the index no longer exists, there's nothing to do.
      return

    existing_indexes = [DatastoreIndex.from_dict(self.project_id, index)
                        for index in json.loads(encoded_indexes)]
    for existing_index in existing_indexes:
      if existing_index.id == index_id:
        existing_index.ready = True

    indexes_dict = [index.to_dict() for index in existing_indexes]
    self._zk_client.set(self.indexes_node, json.dumps(indexes_dict),
                        version=node_version)

  @gen.coroutine
  def _update_indexes(self, encoded_indexes):
    """ Handles changes to the list of a project's indexes.

    Args:
      encoded_indexes: A string containing index node data.
    """
    encoded_indexes = encoded_indexes or '[]'
    self.indexes = [DatastoreIndex.from_dict(self.project_id, index)
                    for index in json.loads(encoded_indexes)]

    # Mark when indexes are defined so they can be backfilled later.
    self._creation_times.update(
      {index.id: time.time() for index in self.indexes
       if not index.ready and index.id not in self._creation_times})

    self.update_event.set()

  def _update_indexes_watch(self, encoded_indexes, znode_stat):
    """ Handles updates to the project's indexes node.

    Args:
      encoded_indexes: A string containing index node data.
      znode_stat: A kazoo.protocol.states.ZnodeStat object.
    """
    if not self.active:
      return False

    IOLoop.current().add_callback(self._update_indexes, encoded_indexes)
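The DataWatch callback fires on a Kazoo thread, which is why it hops to the IOLoop with add_callback before touching the tornado event; the same thread-to-IOLoop bridge in isolation (all names illustrative):

from tornado.ioloop import IOLoop
from tornado.locks import Event

update_event = Event()
loop = IOLoop.current()


def zk_watch_callback(data, stat):
    # Runs on a ZooKeeper thread; add_callback is the only IOLoop
    # method that is safe to call from a foreign thread.
    loop.add_callback(update_event.set)


async def consumer():
    while True:
        await update_event.wait()  # woken by the watch callback
        update_event.clear()       # re-arm for the next change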
Code example #12
class SQSDrain(object):
    """Implementation of IDrain that writes to an AWS SQS queue.
    """

    def __init__(self, logger, loop, sqs_client,
                 metric_prefix='emitter'):
        self.emitter = sqs_client
        self.logger = logger
        self.loop = loop
        self.metric_prefix = metric_prefix
        self.output_error = Event()
        self.state = RUNNING
        self.sender_tag = 'sender:%s.%s' % (self.__class__.__module__,
                                            self.__class__.__name__)
        self._send_queue = Queue()
        self._should_flush_queue = Event()
        self._flush_handle = None
        self.loop.spawn_callback(self._onSend)

    @gen.coroutine
    def _flush_send_batch(self, batch_size):
        send_batch = [
            self._send_queue.get_nowait()
            for pos in range(min(batch_size, self.emitter.max_messages))
        ]
        try:
            response = yield self.emitter.send_message_batch(*send_batch)
        except SQSError as err:
            self.logger.exception('Error encountered flushing data to SQS: %s',
                                  err)
            self.output_error.set()
            for msg in send_batch:
                self._send_queue.put_nowait(msg)
        else:
            if response.Failed:
                self.output_error.set()
                for req in response.Failed:
                    self.logger.error('Message failed to send: %s', req.Id)
                    self._send_queue.put_nowait(req)

    @gen.coroutine
    def _onSend(self):
        respawn = True
        while respawn:
            qsize = self._send_queue.qsize()
            # This will keep flushing until clear,
            # including items that show up in between flushes
            while qsize > 0:
                try:
                    yield self._flush_send_batch(qsize)
                except Exception as err:
                    self.logger.exception(err)
                    self.output_error.set()
                qsize = self._send_queue.qsize()
            # We've cleared the backlog, remove any possible future flush
            if self._flush_handle:
                self.loop.remove_timeout(self._flush_handle)
                self._flush_handle = None
            self._should_flush_queue.clear()
            yield self._should_flush_queue.wait()

    @gen.coroutine
    def close(self, timeout=None):
        self.state = CLOSING
        yield self._send_queue.join(timeout)

    def emit_nowait(self, msg):
        if self._send_queue.qsize() >= self.emitter.max_messages:
            # Signal flush
            self._should_flush_queue.set()
            raise QueueFull()
        elif self._flush_handle is None:
            # Ensure we flush messages at least by MAX_TIMEOUT
            self._flush_handle = self.loop.add_timeout(
                MAX_TIMEOUT,
                lambda: self._should_flush_queue.set(),
            )
        self.logger.debug("Drain emitting")
        self._send_queue.put_nowait(msg)

    @gen.coroutine
    def emit(self, msg, timeout=None):
        if self._send_queue.qsize() >= self.emitter.max_messages:
            # Signal flush
            self._should_flush_queue.set()
        elif self._flush_handle is None:
            # Ensure we flush messages at least by MAX_TIMEOUT
            self._flush_handle = self.loop.add_timeout(
                MAX_TIMEOUT,
                lambda: self._should_flush_queue.set(),
            )
        yield self._send_queue.put(msg, timeout)
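The data Queue plus wake Event handshake between emit_nowait() and _onSend() reduces to this producer/flusher sketch (send_somewhere is an illustrative coroutine):

from tornado import gen
from tornado.locks import Event
from tornado.queues import Queue

queue = Queue()
wake = Event()


@gen.coroutine
def flusher(send_somewhere):
    while True:
        while queue.qsize() > 0:
            yield send_somewhere(queue.get_nowait())
        wake.clear()      # backlog drained; sleep until signalled
        yield wake.wait()


def produce(item):
    queue.put_nowait(item)
    wake.set()            # wake the flusher if it is idle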
Code example #13
class WebUpdater:
    def __init__(self, umgr, config):
        self.umgr = umgr
        self.server = umgr.server
        self.notify_update_response = umgr.notify_update_response
        self.repo = config.get('repo').strip().strip("/")
        self.name = self.repo.split("/")[-1]
        if hasattr(config, "get_name"):
            self.name = config.get_name().split()[-1]
        self.path = os.path.realpath(os.path.expanduser(config.get("path")))
        self.version = self.remote_version = self.dl_url = "?"
        self.etag = None
        self.init_evt = Event()
        self.refresh_condition = None
        self._get_local_version()
        logging.info(f"\nInitializing Client Updater: '{self.name}',"
                     f"\nversion: {self.version}"
                     f"\npath: {self.path}")

    def _get_local_version(self):
        version_path = os.path.join(self.path, ".version")
        if os.path.isfile(os.path.join(self.path, ".version")):
            with open(version_path, "r") as f:
                v = f.read()
            self.version = v.strip()

    async def check_initialized(self, timeout=None):
        if self.init_evt.is_set():
            return
        if timeout is not None:
            timeout = IOLoop.current().time() + timeout
        await self.init_evt.wait(timeout)

    async def refresh(self):
        if self.refresh_condition is None:
            self.refresh_condition = Condition()
        else:
            await self.refresh_condition.wait()
            return
        try:
            self._get_local_version()
            await self._get_remote_version()
        except Exception:
            logging.exception("Error Refreshing Client")
        self.init_evt.set()
        self.refresh_condition.notify_all()
        self.refresh_condition = None

    async def _get_remote_version(self):
        # Remote state
        url = f"https://api.github.com/repos/{self.repo}/releases/latest"
        try:
            result = await self.umgr.github_api_request(url, etag=self.etag)
        except Exception:
            logging.exception(f"Client {self.repo}: Github Request Error")
            result = {}
        if result is None:
            # No change, update not necessary
            return
        self.etag = result.get('etag', None)
        self.remote_version = result.get('name', "?")
        release_assets = result.get('assets', [{}])[0]
        self.dl_url = release_assets.get('browser_download_url', "?")
        logging.info(f"Github client Info Received:\nRepo: {self.name}\n"
                     f"Local Version: {self.version}\n"
                     f"Remote Version: {self.remote_version}\n"
                     f"url: {self.dl_url}")

    async def update(self, *args):
        await self.check_initialized(20.)
        if self.refresh_condition is not None:
            # wait for refresh if in progress
            await self.refresh_condition.wait()
        if self.remote_version == "?":
            await self.refresh()
            if self.remote_version == "?":
                raise self.server.error(
                    f"Client {self.repo}: Unable to locate update")
        if self.dl_url == "?":
            raise self.server.error(
                f"Client {self.repo}: Invalid download url")
        if self.version == self.remote_version:
            # Already up to date
            return
        if os.path.isdir(self.path):
            shutil.rmtree(self.path)
        os.mkdir(self.path)
        self.notify_update_response(f"Downloading Client: {self.name}")
        archive = await self.umgr.http_download_request(self.dl_url)
        with zipfile.ZipFile(io.BytesIO(archive)) as zf:
            zf.extractall(self.path)
        self.version = self.remote_version
        version_path = os.path.join(self.path, ".version")
        if not os.path.exists(version_path):
            with open(version_path, "w") as f:
                f.write(self.version)
        self.notify_update_response(f"Client Update Finished: {self.name}",
                                    is_complete=True)

    def get_update_status(self):
        return {
            'name': self.name,
            'version': self.version,
            'remote_version': self.remote_version
        }
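Note that tornado's Event.wait() takes a deadline in the same format as IOLoop.add_timeout, which is why check_initialized() adds IOLoop.current().time() to the relative timeout first; a timedelta works as well (wait_initialized is an illustrative helper):

import datetime

from tornado.util import TimeoutError


async def wait_initialized(updater):
    try:
        # A timedelta is interpreted as a relative timeout.
        await updater.init_evt.wait(datetime.timedelta(seconds=20))
        return True
    except TimeoutError:
        return False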
Code example #14
File: spec.py Project: cnanakos/distributed
class ProcessInterface:
    """
    An interface for Scheduler and Worker processes for use in SpecCluster

    This interface is responsible for submitting a worker or scheduler process to a
    resource manager like Kubernetes, Yarn, or SLURM/PBS/SGE/...
    It should implement the methods below, like ``start`` and ``close``
    """
    def __init__(self, scheduler=None, name=None):
        self.address = getattr(self, "address", None)
        self.external_address = None
        self.lock = asyncio.Lock()
        self.status = "created"
        self._event_finished = Event()

    def __await__(self):
        async def _():
            async with self.lock:
                if self.status == "created":
                    await self.start()
                    assert self.status == "running"
            return self

        return _().__await__()

    async def start(self):
        """ Submit the process to the resource manager

        For workers this doesn't have to wait until the process actually starts,
        but can return once the resource manager has the request, and will work
        to make the job exist in the future

        For the scheduler we will expect the scheduler's ``.address`` attribute
        to be available after this completes.
        """
        self.status = "running"

    async def close(self):
        """ Close the process

        This will be called by the Cluster object when we scale down a node,
        but only after we ask the Scheduler to close the worker gracefully.
        This method should kill the process a bit more forcefully and does not
        need to worry about shutting down gracefully
        """
        self.status = "closed"
        self._event_finished.set()

    async def finished(self):
        """ Wait until the server has finished """
        await self._event_finished.wait()

    def __repr__(self):
        return "<%s: status=%s>" % (type(self).__name__, self.status)

    async def __aenter__(self):
        await self
        return self

    async def __aexit__(self, *args, **kwargs):
        await self.close()
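Because ProcessInterface implements __await__ and the async context-manager hooks, a process can be driven in either style; a quick sketch (assumes the Event is tornado.locks.Event, which interoperates with the asyncio loop in Tornado 5+):

import asyncio


async def main():
    proc = ProcessInterface()
    await proc              # runs start() under the lock; status == "running"
    await proc.close()      # sets status to "closed" and the finished event
    await proc.finished()   # resolves immediately once close() has run


asyncio.get_event_loop().run_until_complete(main())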
Code example #15
File: core.py Project: amosonn/distributed
class ConnectionPool(object):
    """ A maximum sized pool of Tornado IOStreams

    This provides a connect method that mirrors the normal distributed.connect
    method, but provides connection sharing and tracks connection limits.

    This object provides an ``rpc`` like interface::

        >>> rpc = ConnectionPool(limit=512)
        >>> scheduler = rpc('127.0.0.1:8786')
        >>> workers = [rpc(ip=ip, port=port) for ip, port in ...]

        >>> info = yield scheduler.identity()

    It creates enough streams to satisfy concurrent connections to any
    particular address::

        >>> a, b = yield [scheduler.who_has(), scheduler.has_what()]

    It reuses existing streams so that we don't have to continuously reconnect.

    It also maintains a stream limit to avoid "too many open file handles"
    issues.  Whenever this maximum is reached we clear out all idling streams.
    If that doesn't do the trick then we wait until one of the occupied streams
    closes.
    """
    def __init__(self, limit=512):
        self.open = 0
        self.active = 0
        self.limit = limit
        self.available = defaultdict(set)
        self.occupied = defaultdict(set)
        self.event = Event()

    def __str__(self):
        return "<ConnectionPool: open=%d, active=%d>" % (self.open,
                self.active)

    __repr__ = __str__

    def __call__(self, arg=None, ip=None, port=None, addr=None):
        """ Cached rpc objects """
        ip, port = ip_port_from_args(arg=arg, addr=addr, ip=ip, port=port)
        return RPCCall(ip, port, self)

    @gen.coroutine
    def connect(self, ip, port, timeout=3):
        if self.available.get((ip, port)):
            stream = self.available[ip, port].pop()
            self.active += 1
            self.occupied[ip, port].add(stream)
            raise gen.Return(stream)

        while self.open >= self.limit:
            self.event.clear()
            self.collect()
            yield self.event.wait()

        self.open += 1
        stream = yield connect(ip=ip, port=port, timeout=timeout)
        stream.set_close_callback(lambda: self.on_close(ip, port, stream))
        self.active += 1
        self.occupied[ip, port].add(stream)

        if self.open >= self.limit:
            self.event.clear()

        raise gen.Return(stream)

    def on_close(self, ip, port, stream):
        self.open -= 1

        if stream in self.available[ip, port]:
            self.available[ip, port].remove(stream)
        if stream in self.occupied[ip, port]:
            self.occupied[ip, port].remove(stream)
            self.active -= 1

        if self.open <= self.limit:
            self.event.set()

    def collect(self):
        logger.info("Collecting unused streams.  open: %d, active: %d",
                    self.open, self.active)
        for k, streams in list(self.available.items()):
            for stream in streams:
                stream.close()
Code example #16
class DebugpyClient:
    def __init__(self, log, debugpy_stream, event_callback):
        self.log = log
        self.debugpy_stream = debugpy_stream
        self.event_callback = event_callback
        self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
        self.debugpy_host = '127.0.0.1'
        self.debugpy_port = -1
        self.routing_id = None
        self.wait_for_attach = True
        self.init_event = Event()
        self.init_event_seq = -1

    def _get_endpoint(self):
        host, port = self.get_host_port()
        return 'tcp://' + host + ':' + str(port)

    def _forward_event(self, msg):
        if msg['event'] == 'initialized':
            self.init_event.set()
            self.init_event_seq = msg['seq']
        self.event_callback(msg)

    def _send_request(self, msg):
        if self.routing_id is None:
            self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
        content = jsonapi.dumps(
            msg,
            default=json_default,
            ensure_ascii=False,
            allow_nan=False,
        )
        content_length = str(len(content))
        buf = (DebugpyMessageQueue.HEADER + content_length +
               DebugpyMessageQueue.SEPARATOR).encode('ascii')
        buf += content
        self.log.debug("DEBUGPYCLIENT:")
        self.log.debug(self.routing_id)
        self.log.debug(buf)
        self.debugpy_stream.send_multipart((self.routing_id, buf))

    async def _wait_for_response(self):
        # Since events are never pushed to the message_queue
        # we can safely assume the next message in queue
        # will be an answer to the previous request
        return await self.message_queue.get_message()

    async def _handle_init_sequence(self):
        # 1] Waits for initialized event
        await self.init_event.wait()

        # 2] Sends configurationDone request
        configurationDone = {
            'type': 'request',
            'seq': int(self.init_event_seq) + 1,
            'command': 'configurationDone'
        }
        self._send_request(configurationDone)

        # 3]  Waits for configurationDone response
        await self._wait_for_response()

        # 4] Waits for attachResponse and returns it
        attach_rep = await self._wait_for_response()
        return attach_rep

    def get_host_port(self):
        if self.debugpy_port == -1:
            socket = self.debugpy_stream.socket
            socket.bind_to_random_port('tcp://' + self.debugpy_host)
            self.endpoint = socket.getsockopt(
                zmq.LAST_ENDPOINT).decode('utf-8')
            socket.unbind(self.endpoint)
            index = self.endpoint.rfind(':')
            self.debugpy_port = self.endpoint[index + 1:]
        return self.debugpy_host, self.debugpy_port

    def connect_tcp_socket(self):
        self.debugpy_stream.socket.connect(self._get_endpoint())
        self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)

    def disconnect_tcp_socket(self):
        self.debugpy_stream.socket.disconnect(self._get_endpoint())
        self.routing_id = None
        self.init_event = Event()
        self.init_event_seq = -1
        self.wait_for_attach = True

    def receive_dap_frame(self, frame):
        self.message_queue.put_tcp_frame(frame)

    async def send_dap_request(self, msg):
        self._send_request(msg)
        if self.wait_for_attach and msg['command'] == 'attach':
            rep = await self._handle_init_sequence()
            self.wait_for_attach = False
            return rep
        else:
            rep = await self._wait_for_response()
            self.log.debug('DEBUGPYCLIENT - returning:')
            self.log.debug(rep)
            return rep
Code example #17
File: pubnub_tornado.py Project: pubnub/python
class SubscribeListener(SubscribeCallback):
    def __init__(self):
        self.connected = False
        self.connected_event = Event()
        self.disconnected_event = Event()
        self.presence_queue = Queue()
        self.message_queue = Queue()
        self.error_queue = Queue()

    def status(self, pubnub, status):
        if utils.is_subscribed_event(status) and not self.connected_event.is_set():
            self.connected_event.set()
        elif utils.is_unsubscribed_event(status) and not self.disconnected_event.is_set():
            self.disconnected_event.set()
        elif status.is_error():
            self.error_queue.put_nowait(status.error_data.exception)

    def message(self, pubnub, message):
        self.message_queue.put(message)

    def presence(self, pubnub, presence):
        self.presence_queue.put(presence)

    @tornado.gen.coroutine
    def _wait_for(self, coro):
        error = self.error_queue.get()
        wi = tornado.gen.WaitIterator(coro, error)

        while not wi.done():
            result = yield wi.next()

            if wi.current_future == coro:
                raise gen.Return(result)
            elif wi.current_future == error:
                raise result
            else:
                raise Exception("Unexpected future resolved: %s" % str(wi.current_future))

    @tornado.gen.coroutine
    def wait_for_connect(self):
        if not self.connected_event.is_set():
            yield self._wait_for(self.connected_event.wait())
        else:
            raise Exception("instance is already connected")

    @tornado.gen.coroutine
    def wait_for_disconnect(self):
        if not self.disconnected_event.is_set():
            yield self._wait_for(self.disconnected_event.wait())
        else:
            raise Exception("instance is already disconnected")

    @tornado.gen.coroutine
    def wait_for_message_on(self, *channel_names):
        channel_names = list(channel_names)
        while True:
            try: # NOQA
                env = yield self._wait_for(self.message_queue.get())
                if env.channel in channel_names:
                    raise tornado.gen.Return(env)
                else:
                    continue
            finally:
                self.message_queue.task_done()

    @tornado.gen.coroutine
    def wait_for_presence_on(self, *channel_names):
        channel_names = list(channel_names)
        while True:
            try:
                try:
                    env = yield self._wait_for(self.presence_queue.get())
                except: # NOQA E722 pylint: disable=W0702
                    break
                if env.channel in channel_names:
                    raise tornado.gen.Return(env)
                else:
                    continue
            finally:
                self.presence_queue.task_done()
Code example #18
File: corr_monitoring_loop.py Project: ska-sa/corr2
class MonitoringLoop(object):
    def __init__(self, check_time, fx_correlator_object):

        self.instrument = fx_correlator_object
        self.hosts = self.instrument.fhosts + self.instrument.xhosts
        self.selected_host = None
        self.host_index = 0
        self.num_hosts = len(self.hosts)
        self.num_fhosts = len(self.instrument.fhosts)
        self.num_xhosts = len(self.instrument.xhosts)
        # self.num_bhosts = len(self.instrument.bhosts)

        # check config file if bhosts or xhosts

        if check_time == -1:
            self.check_time = float(self.instrument.configd['FxCorrelator']['monitor_loop_time'])
        else:
            self.check_time = check_time

        # set up periodic engine monitoring
        self.instrument_monitoring_loop_enabled = IOLoopEvent()
        self.instrument_monitoring_loop_enabled.clear()
        self.instrument_monitoring_loop_cb = None

        self.f_eng_board_monitoring_dict_prev = {}
        self.x_eng_board_monitoring_dict_prev = {}
        self.b_eng_board_monitoring_dict_prev = {}

        self.disabled_fhosts = []
        self.disabled_xhosts = []
        self.disabled_bhosts = []

        # some other useful bits of info
        self.n_chans = self.instrument.n_chans
        self.chans_per_xhost = self.n_chans / self.num_xhosts

    def start(self):
        """
        Start the monitoring loop
        :return: none
        """
        self._instrument_monitoring_loop_timer_start(check_time=self.check_time)

    def stop(self):
        """
        Stop the monitoring loop
        :return: none
        """
        self._instrument_monitoring_loop_timer_stop()

    def _instrument_monitoring_loop_timer_start(self, check_time=None):
        """
        Set up periodic check of various instrument elements
        :param check_time: the interval, in seconds, at which to check
        :return:
        """

        if not IOLoop.current()._running:
            raise RuntimeError('IOLoop not running, this will not work')

        self.instrument.logger.info('instrument_monitoring_loop for instrument %s '
                                    'set up with a period '
                                    'of %.2f seconds' % (self.instrument.descriptor, self.check_time))

        if self.instrument_monitoring_loop_cb is not None:
            self.instrument_monitoring_loop_cb.stop()
        self.instrument_monitoring_loop_cb = PeriodicCallback(
            self._instrument_monitoring_loop, check_time * 1000)

        self.instrument_monitoring_loop_enabled.set()
        self.instrument_monitoring_loop_cb.start()
        self.instrument.logger.info('Instrument Monitoring Loop Timer '
                                    'Started @ %s' % time.ctime())

    def _instrument_monitoring_loop_timer_stop(self):
        """
        Disable the periodic instrument monitoring loop
        :return:
        """

        if self.instrument_monitoring_loop_cb is not None:
            self.instrument_monitoring_loop_cb.stop()
        self.instrument_monitoring_loop_cb = None
        self.instrument_monitoring_loop_enabled.clear()
        self.instrument.logger.info('Instrument Monitoring Loop Timer '
                                    'Halted @ %s' % time.ctime())

    # TODO: use functools to pass this callback function with parameters
    def _instrument_monitoring_loop(self, check_fhosts=True, check_xhosts=True, check_bhosts=False):
        """
        Perform various checks periodically.
        :param check_fhosts: enable periodic checking of the f-hosts;
        will disable F-engine output on overflow
        :param check_xhosts: enable periodic checking of the x-hosts
        :param check_bhosts: enable periodic checking of the b-hosts
        :return:
        """

        # TODO: refactor this to handle checking of all host types
        # TODO: run all tests on everything?
        # TODO: figure out how to selectively test pieces
        # select a new host
        host = self.hosts[self.host_index]
        board_monitoring_dict_current = {}

        # check host type
        if host.host_type == 'fhost':
            if check_fhosts:
                board_monitoring_dict_current[host] = self._get_fhost_status(host=host)

                # check error counters if all fhosts have status
                if len(self.f_eng_board_monitoring_dict_prev) == self.num_fhosts:
                    self._check_fhost_errors(board_monitoring_dict_current, host)

                self.f_eng_board_monitoring_dict_prev[host] = \
                    board_monitoring_dict_current[host]

        elif host.host_type == 'xhost' or host.host_type == 'bhost':
            if check_xhosts:
                board_monitoring_dict_current[host] = self._get_xhost_status(host=host)

                # check errs if all xhosts have status
                if len(self.x_eng_board_monitoring_dict_prev) == self.num_xhosts:
                    self._check_xhost_errors(board_monitoring_dict_current, host)

                self.x_eng_board_monitoring_dict_prev[host] = \
                    board_monitoring_dict_current[host]

        # TODO: how to handle bhosts and xhosts?
        elif host.host_type == 'bhost':
            if check_bhosts:
                pass

        # increment board counter, move to the next board the next time
        # loop runs

        if self.host_index == self.num_hosts - 1:
            if not self.disabled_fhosts and not self.disabled_xhosts and not self.disabled_bhosts:
                self.instrument.logger.info('Monitoring loop run ok. All hosts checked - no hosts disabled')
            else:
                self.instrument.logger.warning(
                    'Monitoring loop run ok. All hosts checked - some hosts disabled')

                # list the disabled fhosts
                if self.disabled_fhosts:
                    self.instrument.logger.warning(
                        'corr2 monitor loop: disabled f-hosts: %s' % [
                            'fhost%d:%s:%s' % (
                            disabled_fhost.fhost_index, disabled_fhost.host,
                            [feng.input.name for feng in disabled_fhost.fengines])
                            for disabled_fhost in self.disabled_fhosts])

                # list the disabled xhosts
                if self.disabled_xhosts:
                    self.instrument.logger.warning(
                        'corr2 monitor loop: disabled x-hosts: %s'
                        % ['xhost%d:%s:%d-%d' % (
                        disabled_xhost.index, disabled_xhost.host,
                        self.instrument.xops.board_ids[
                            disabled_xhost.host] * self.chans_per_xhost,
                        (self.instrument.xops.board_ids[
                             disabled_xhost.host] + 1) * self.chans_per_xhost - 1)
                       for disabled_xhost in self.disabled_xhosts])

            # reset the host counter to start checking again
            self.host_index = 0
        else:
            self.host_index += 1

        #if self.disabled_bhosts:
        #    self.instrument.logger.warning('corr2 monitor loop: disabled b-hosts: %s' % [
        #        '%d:%s' % (disabled_bhost.index, disabled_bhost.host) for disabled_bhost in self.disabled_bhosts])

        return True

    def _get_fhost_status(self, host, corner_turner_check=True,
                          coarse_delay_check=True, rx_reorder_check=True):
        """
        Checks the f-hosts for errors

        :return:
        """
        status = {}
        # check ct & cd
        if corner_turner_check:
            # perform corner-turner check
            ct_status = host.get_ct_status()
            status['corner_turner'] = ct_status
        if coarse_delay_check:
            # perform coarse delay check
            cd_status = host.get_cd_status()
            status['coarse_delay'] = cd_status

        # check feng rx reorder
        if rx_reorder_check:
            feng_rx_reorder_status = host.get_rx_reorder_status()
            status['feng_rx_reorder'] = feng_rx_reorder_status

        return status

    def _check_fhost_errors(self, board_monitoring_dict_current, host):
        """

        :param board_monitoring_dict_current:
        :param host:
        :return:
        """
        action = {'disable_output': 0,
                  'reenable_output': 0}

        coarse_delay_action = self._check_feng_coarse_delay_errs(
            board_monitoring_dict_current[host], host)

        corner_turner_action = self._check_feng_corner_turn_errs(
            board_monitoring_dict_current[host], host)

        rx_reorder_action = self._check_feng_rx_reorder_errs(
            board_monitoring_dict_current[host], host)

        # consolidate the action dictionary - only reenable if all
        # errors are cleared
        if coarse_delay_action['disable_output'] \
                or corner_turner_action['disable_output'] \
                or rx_reorder_action['disable_output']:
            action['disable_output'] = True
        elif coarse_delay_action['reenable_output'] \
                or corner_turner_action['reenable_output'] \
                or rx_reorder_action['reenable_output']:
            action['reenable_output'] = True
        else:
            # no action required
            pass

        # take appropriate action on board

        if action['disable_output']:
            # keep track of which boards have been disabled
            if host not in self.disabled_fhosts:
                self.disabled_fhosts.append(host)
                self._disable_feng_ouput(fhost=host)

        elif action['reenable_output']:
            # after checking, we already know that this board was
            # disabled prior
            self._renable_feng_output(fhost=host)
            # remove the board from the list of disabled boards
            self.disabled_fhosts.remove(host)
        else:
            # no action taken
            pass

    def _get_xhost_status(self, host, xeng_rx_reorder_check=False, xeng_hmc_reorder_check=True, xeng_vacc_check=True):
        """

        :param host:
        :param xeng_rx_reorder_check:
        :param xeng_hmc_reorder_check:
        :param xeng_vacc_check:
        :return:
        """

        status = {}
        # check xeng rx reorder

        if xeng_rx_reorder_check:
            xeng_rx_reorder_status = host.get_rx_reorder_status()
            status['xeng_rx_reorder'] = xeng_rx_reorder_status

        # check xeng hmc reorder
        if xeng_hmc_reorder_check:
            xeng_hmc_reorder_status = host.get_hmc_reorder_status()
            status['xeng_hmc_reorder'] = xeng_hmc_reorder_status

        # check xeng vacc
        if xeng_vacc_check:
            xeng_vacc_status = host.get_vacc_status()
            status['xeng_vacc'] = xeng_vacc_status

        return status

    def _check_xhost_errors(self, board_monitoring_dict_current, host):
        """

        :param board_monitoring_dict_current:
        :param host:
        :return:
        """
        action = {'disable_output': 0,
                  'reenable_output': 0}

        hmc_reorder_action = self._check_xeng_hmc_reorder_errs(
            board_monitoring_dict_current[host], host)

        vacc_action = self._check_xeng_vacc_errs(
            board_monitoring_dict_current[host], host)

        # consolidate the action dictionary - only reenable if all
        # errors are cleared
        if hmc_reorder_action['disable_output'] \
                or vacc_action['disable_output']:
            action['disable_output'] = True
        elif hmc_reorder_action['reenable_output'] \
                or vacc_action['reenable_output']:
            action['reenable_output'] = True
        else:
            # no action required
            pass

        # take appropriate action on board

        if action['disable_output']:
            # keep track of which boards have been disabled
            if host not in self.disabled_xhosts:
                self.disabled_xhosts.append(host)
                self._disable_xeng_ouput(xhost=host)

        elif action['reenable_output']:
            # after checking, we already know that this board was
            # disabled prior
            self._renable_xeng_output(xhost=host)
            # remove the board from the list of disabled boards
            self.disabled_xhosts.remove(host)
        else:
            # no action taken
            pass

        self.x_eng_board_monitoring_dict_prev[host] = \
            board_monitoring_dict_current[host]

    def _check_xeng_rx_reorder_errs(self, x_eng_status_dict, xhost):
        """

        :param x_eng_status_dict:
        :param xhost:
        :return:
        """
        raise NotImplementedError

    def _check_xeng_hmc_reorder_errs(self, x_eng_status_dict, xhost):
        """

        :param x_eng_status_dict:
        :param xhost:
        :return:
        """
        action_dict = {'disable_output': False, 'reenable_output': False}

        if 'xeng_hmc_reorder' in x_eng_status_dict:

            hmc_reorder_dict = x_eng_status_dict['xeng_hmc_reorder']

            # counters
            # check error counters, first check if a previous status
            # dict exists
            # flags
            if not hmc_reorder_dict['init_done']:
                self.instrument.logger.warning(
                    'xhost %s hmc reorder has init errors' %
                     xhost.host)
                action_dict['disable_output'] = True

            if not hmc_reorder_dict['post_ok']:
                self.instrument.logger.warning('xhost %s hmc reorder '
                                               'has post errors' % xhost.host)
                action_dict['disable_output'] = True

            # check error counters
            if self.x_eng_board_monitoring_dict_prev:
                hmc_reorder_dict_prev = \
                    self.x_eng_board_monitoring_dict_prev[xhost][
                        'xeng_hmc_reorder']
#TODO Ignore CRC errors on the HMCs for now...
#                if hmc_reorder_dict['err_cnt_link2'] != hmc_reorder_dict_prev[
#                    'err_cnt_link2']:
#                    self.instrument.logger.warning(
#                        'xhost %s hmc reorder has errors on link 2' %
#                         xhost.host)
#                    action_dict['disable_output'] = True
#                if hmc_reorder_dict['err_cnt_link3'] != hmc_reorder_dict_prev[
#                    'err_cnt_link3']:
#                    self.instrument.logger.warning(
#                        'xhost %s hmc reorder has errors on link 3' %
#                        xhost.host)
#                    action_dict['disable_output'] = True
                if hmc_reorder_dict['lnk2_nrdy_err_cnt'] != hmc_reorder_dict_prev[
                    'lnk2_nrdy_err_cnt']:
                    self.instrument.logger.warning(
                        'xhost %s hmc reorder has link 2 nrdy errors' %
                        xhost.host)
                    action_dict['disable_output'] = True
                if hmc_reorder_dict['lnk3_nrdy_err_cnt'] != hmc_reorder_dict_prev[
                    'lnk3_nrdy_err_cnt']:
                    self.instrument.logger.warning(
                        'xhost %s hmc reorder has link 3 nrdy errors' %
                        xhost.host)
                    action_dict['disable_output'] = True
                if hmc_reorder_dict['mcnt_timeout_cnt'] != hmc_reorder_dict_prev[
                    'mcnt_timeout_cnt']:
                    self.instrument.logger.warning(
                        'xhost %s hmc reorder has mcnt timeout errors' %
                        xhost.host)
                    action_dict['disable_output'] = True
                if hmc_reorder_dict['ts_err_cnt'] != hmc_reorder_dict_prev[
                    'ts_err_cnt']:
                    self.instrument.logger.warning(
                        'xhost %s hmc reorder has timestamp errors' %
                        xhost.host)
                    action_dict['disable_output'] = True

        # if no errors, check if board was disabled, then flag for reenable
        if not action_dict['disable_output']:
            # no errors detected, no need to disable any boards
            # but was the board previously disabled?
            if xhost in self.disabled_xhosts:
                # the errors on the board have been cleared
                action_dict['reenable_output'] = True

        return action_dict

    def _check_xeng_vacc_errs(self, x_eng_status_dict, xhost):
        """

        :param x_eng_status_dict:
        :param xhost:
        :return:
        """
        action_dict = {'disable_output': False, 'reenable_output': False}

        if 'xeng_vacc' in x_eng_status_dict:

            vacc_dict = x_eng_status_dict['xeng_vacc']

            # check error counters, first check if a previous status
            # dict exists
            if self.x_eng_board_monitoring_dict_prev:
                vacc_dict_prev = \
                    self.x_eng_board_monitoring_dict_prev[xhost][
                        'xeng_vacc']

                for vacc in range(len(vacc_dict)):
                    # there are four vaccs per xhost

                    if vacc_dict[vacc]['err_cnt'] != vacc_dict_prev[vacc]['err_cnt']:
                        self.instrument.logger.warning(
                            'xhost %s vacc has errors' %
                            xhost.host)
                        action_dict['disable_output'] = True

        # if no errors, check if board was disabled, then flag for reenable
        if not action_dict['disable_output']:
            # no errors detected, no need to disable any boards
            # but was the board previously disabled?
            if xhost in self.disabled_xhosts:
                # the errors on the board have been cleared
                action_dict['reenable_output'] = True

        return action_dict

    def _check_feng_rx_reorder_errs(self, f_eng_status_dict, fhost):
        """

        :param f_eng_status_dict:
        :param fhost:
        :return:
        """
        action_dict = {'disable_output': False, 'reenable_output': False}

        if 'feng_rx_reorder' in f_eng_status_dict:

            rx_reorder_dict = f_eng_status_dict['feng_rx_reorder']

            # check error counters, first check if a previous status
            # dict exists

            if self.f_eng_board_monitoring_dict_prev:
                rx_reorder_dict_prev = \
                    self.f_eng_board_monitoring_dict_prev[fhost][
                        'feng_rx_reorder']
                # check error counters
                if rx_reorder_dict['overflow_err_cnt'] != rx_reorder_dict_prev[
                    'overflow_err_cnt']:
                    self.instrument.logger.warning(
                        'fhost %s rx reorder has overflow errors' %
                        fhost.host)
                    action_dict['disable_output'] = True
                if rx_reorder_dict['receive_err_cnt'] != rx_reorder_dict_prev[
                    'receive_err_cnt']:
                    self.instrument.logger.warning(
                        'fhost %s rx reorder has receive errors' %
                        fhost.host)
                    action_dict['disable_output'] = True
                if rx_reorder_dict['relock_err_cnt'] != rx_reorder_dict_prev[
                    'relock_err_cnt']:
                    self.instrument.logger.warning(
                        'fhost %s rx reorder has relock errors' %
                        fhost.host)
                    action_dict['disable_output'] = True
                if rx_reorder_dict['timestep_err_cnt'] != rx_reorder_dict_prev[
                    'timestep_err_cnt']:
                    self.instrument.logger.warning(
                        'fhost %s rx reorder has timestep errors' %
                        fhost.host)
                    action_dict['disable_output'] = True

        # if no errors, check if board was disabled, then flag for reenable
        if not action_dict['disable_output']:
            # no errors detected, no need to disable any boards
            # but was the board previously disabled?
            if fhost in self.disabled_fhosts:
                # the errors on the board have been cleared
                action_dict['reenable_output'] = True

        return action_dict

    def _check_feng_coarse_delay_errs(self, f_eng_status_dict, fhost):
        """
        Check f-engines for any coarse delay errors
        :return:
        """
        action_dict = {'disable_output': False, 'reenable_output': False}

        if 'coarse_delay' in f_eng_status_dict:

            cd_dict = f_eng_status_dict['coarse_delay']
            # flags
            if not cd_dict['hmc_init']:
                self.instrument.logger.warning(
                    'fhost %s coarse delay has hmc init errors' %
                    fhost.host)
                action_dict['disable_output'] = True

            if not cd_dict['hmc_post']:
                self.instrument.logger.warning('fhost %s coarse delay '
                                               'has hmc post '
                                               'errors' %
                                               fhost.host)
                action_dict['disable_output'] = True

            # check error counters, first check if a previous status
            # dict exists

            if self.f_eng_board_monitoring_dict_prev:
                cd_dict_prev = \
                    self.f_eng_board_monitoring_dict_prev[fhost][
                        'coarse_delay']
                # check error counters
                if cd_dict['reord_jitter_err_cnt_pol0'] != cd_dict_prev[
                        'reord_jitter_err_cnt_pol0'] or cd_dict[
                        'reord_jitter_err_cnt_pol1'] != cd_dict_prev[
                        'reord_jitter_err_cnt_pol1']:
                    self.instrument.logger.warning(
                        'fhost %s coarse delay has reorder jitter errors' %
                        fhost.host)
                    action_dict['disable_output'] = True
                if cd_dict['hmc_overflow_err_cnt_pol0'] != cd_dict_prev[
                        'hmc_overflow_err_cnt_pol0'] or cd_dict[
                        'hmc_overflow_err_cnt_pol1'] != cd_dict_prev[
                        'hmc_overflow_err_cnt_pol1']:
                    self.instrument.logger.warning('fhost %s coarse '
                                                   'delay has '
                                                   'overflow errors' %
                                                   fhost.host)
                    action_dict['disable_output'] = True

        # if no errors, check if board was disabled, then flag for reenable
        if not action_dict['disable_output']:
            # no errors detected, no need to disable any boards
            # but was the board previously disabled?
            if fhost in self.disabled_fhosts:
                # the errors on the board have been cleared
                action_dict['reenable_output'] = True

        return action_dict

    def _check_feng_corner_turn_errs(self, f_eng_status_dict, fhost):
        """
        Check f-engines for any corner-turner errors

        :return:
        """

        action_dict = {'disable_output': False, 'reenable_output': False}

        if 'corner_turner' in f_eng_status_dict:

            ct_dict = f_eng_status_dict['corner_turner']
            # flags
            if not ct_dict['hmc_init_pol0'] or not ct_dict['hmc_init_pol1']:
                self.instrument.logger.warning('fhost %s corner-turner '
                                               'has hmc '
                                               'init errors' %
                                               fhost.host)
                action_dict['disable_output'] = True
            if not ct_dict['hmc_post_pol0'] or not ct_dict['hmc_post_pol1']:
                self.instrument.logger.warning('fhost %s corner-turner '
                                               'has hmc post '
                                               'errors' %
                                               fhost.host)
                action_dict['disable_output'] = True

            # check error counters, first check if a previous status
            # dict exists
            if self.f_eng_board_monitoring_dict_prev:
                ct_dict_prev = self.f_eng_board_monitoring_dict_prev[fhost][
                    'corner_turner']
                # check error counters
                if ct_dict['bank_err_cnt_pol0'] != ct_dict_prev[
                        'bank_err_cnt_pol0'] or ct_dict[
                        'bank_err_cnt_pol1'] != ct_dict_prev[
                        'bank_err_cnt_pol1']:
                    self.instrument.logger.warning('fhost %s '
                                                   'corner-turner has bank errors' %
                                                   fhost.host)
                    action_dict['disable_output'] = True
                if ct_dict['fifo_full_err_cnt'] != ct_dict_prev[
                    'fifo_full_err_cnt']:
                    self.instrument.logger.warning('fhost %s '
                                                   'corner-turner has fifo full '
                                                   'errors' % fhost.host)
                    action_dict['disable_output'] = True
                if ct_dict['rd_go_err_cnt'] != ct_dict_prev['rd_go_err_cnt']:
                    self.instrument.logger.warning('fhost %s '
                                                   'corner-turner has read go '
                                                   'errors' % fhost.host)
                    action_dict['disable_output'] = True
                if ct_dict['obuff_bank_err_cnt'] != ct_dict_prev[
                    'obuff_bank_err_cnt']:
                    self.instrument.logger.warning('fhost %s '
                                                   'corner-turner has '
                                                   'obuff errors' % fhost.host)
                    action_dict['disable_output'] = True
                if ct_dict['hmc_overflow_err_cnt_pol0'] != ct_dict_prev[
                        'hmc_overflow_err_cnt_pol0'] or ct_dict[
                        'hmc_overflow_err_cnt_pol1'] != ct_dict_prev[
                        'hmc_overflow_err_cnt_pol1']:
                    self.instrument.logger.warning('fhost %s '
                                                   'corner-turner has '
                                                   'overflow errors' %
                                                   fhost.host)
                    action_dict['disable_output'] = True

        # if no errors, check if board was disabled, then flag for reenable
        if not action_dict['disable_output']:
            # no errors detected, no need to disable any boards
            # but was the board previously disabled?
            if fhost in self.disabled_fhosts:
                # the errors on the board have been cleared
                action_dict['reenable_output'] = True

        return action_dict

    def _disable_feng_ouput(self, fhost):
        """
        Disables the output from an f-engine
        :param fhost: the host board with f-engines to disable
        :return:
        """

        fhost.tx_disable()
        self.instrument.logger.warning('fhost%d %s %s output disabled!' %
                                       (fhost.fhost_index,
                                        fhost.host,
                                        [feng.input.name for feng in fhost.fengines]
                                        ))

    def _renable_feng_output(self, fhost):
        """
        Reenables the output from an f-engine
        :param fhost: the host board with f-engines to reenable
        :return:
        """

        fhost.tx_enable()
        self.instrument.logger.info('fhost%d %s %s output reenabled!' %
                                    (fhost.fhost_index,
                                     fhost.host,
                                     [feng.input.name for feng in
                                      fhost.fengines]
                                     ))

    def _disable_xeng_ouput(self, xhost):
        """
        Disables the output from an f-engine
        :param xhost: the host board with f-engines to disable
        :return:
        """

        xhost.registers.control.write(gbe_txen=False)
        self.instrument.logger.warning('xhost%d %s %s output disabled!' %
                                      (xhost.index, xhost.host,
                                      (self.instrument.xops.board_ids[xhost.host] * self.chans_per_xhost,
                                       (self.instrument.xops.board_ids[xhost.host] + 1) * self.chans_per_xhost - 1)
                                      ))

    def _renable_xeng_output(self, xhost):
        """
        Reenables the output from an f-engine
        :param xhost: the host board with f-engines to reenable
        :return:
        """

        xhost.registers.control.write(gbe_txen=True)
        self.instrument.logger.info('xhost%d %s %s output reenabled!' %
                                    (xhost.index, xhost.host,
                                     (self.instrument.xops.board_ids[
                                          xhost.host] * self.chans_per_xhost,
                                      (self.instrument.xops.board_ids[
                                           xhost.host] + 1) * self.chans_per_xhost - 1)
                                     ))

    def get_bad_fhosts(self):
        """
        Returns a list of known bad fhosts that are currently disabled
        :return: list of disabled fhosts
        """

        return self.disabled_fhosts

    def get_bad_xhosts(self):
        """
        Returns a list of known bad xhosts that are currently disabled
        :return: list of disabled xhosts
        """

        return self.disabled_xhosts

    def get_bad_bhosts(self):
        """
        Returns a list of known bad bhosts that are currently disabled
        :return: list of disabled bhosts
        """

        return self.disabled_bhosts
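
All of the _check_* methods above share one delta pattern: compare each error
counter against the snapshot taken on the previous monitoring pass, flag the
board's output for disabling if anything moved, and flag it for re-enabling
once the counters hold steady. A minimal sketch of that pattern as a
standalone helper (the helper name and usage are illustrative, not part of
the source):

def check_counters(curr, prev, counter_names):
    """Build an action dict in the same shape the _check_* methods return.

    curr and prev are snapshots of the same error counters taken on
    consecutive monitoring passes; prev may be None on the first pass.
    """
    action = {'disable_output': False, 'reenable_output': False}
    if prev is not None:
        # any counter that moved since the last pass indicates an error
        if any(curr[name] != prev[name] for name in counter_names):
            action['disable_output'] = True
    return action

# illustrative usage with the HMC reorder counters checked above:
# action = check_counters(hmc_reorder_dict, hmc_reorder_dict_prev,
#                         ['lnk2_nrdy_err_cnt', 'lnk3_nrdy_err_cnt',
#                          'mcnt_timeout_cnt', 'ts_err_cnt'])
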
Code example #19
File: executor.py Project: freeman-lab/distributed
class Executor(object):
    """ Distributed executor with data dependencies

    This executor resembles executors in concurrent.futures but also allows
    Futures within submit/map calls.

    Provide center address on initialization

    >>> executor = Executor(('127.0.0.1', 8787))  # doctest: +SKIP

    Use ``submit`` method like normal

    >>> a = executor.submit(add, 1, 2)  # doctest: +SKIP
    >>> b = executor.submit(add, 10, 20)  # doctest: +SKIP

    Additionally, provide results of submit calls (futures) to further submit
    calls:

    >>> c = executor.submit(add, a, b)  # doctest: +SKIP

    This allows for the dynamic creation of complex dependencies.
    """
    def __init__(self, center, start=True, delete_batch_time=1, loop=None):
        self.center = coerce_to_rpc(center)
        self.futures = dict()
        self.refcount = defaultdict(lambda: 0)
        self.loop = loop or IOLoop()
        self.scheduler = Scheduler(center, delete_batch_time=delete_batch_time)

        if start:
            self.start()

    def start(self):
        """ Start scheduler running in separate thread """
        if hasattr(self, '_loop_thread'):
            return
        from threading import Thread
        self._loop_thread = Thread(target=self.loop.start)
        self._loop_thread.daemon = True
        _global_executors.add(self)
        self._loop_thread.start()
        sync(self.loop, self._start)

    @gen.coroutine
    def _start(self):
        yield self.scheduler._sync_center()
        self._scheduler_start_event = Event()
        self.coroutines = [self.scheduler.start(), self.report()]
        _global_executors.add(self)
        yield self._scheduler_start_event.wait()
        logger.debug("Started scheduling coroutines. Synchronized")

    @property
    def scheduler_queue(self):
        return self.scheduler.scheduler_queue

    @property
    def report_queue(self):
        return self.scheduler.report_queue

    def __enter__(self):
        if not self.loop._running:
            self.start()
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()

    def _inc_ref(self, key):
        self.refcount[key] += 1

    def _dec_ref(self, key):
        self.refcount[key] -= 1
        if self.refcount[key] == 0:
            del self.refcount[key]
            self._release_key(key)

    def _release_key(self, key):
        """ Release key from distributed memory """
        logger.debug("Release key %s", key)
        if key in self.futures:
            self.futures[key]['event'].clear()
            del self.futures[key]
        self.loop.add_callback(self.scheduler_queue.put_nowait,
                {'op': 'release-held-data', 'key': key})

    @gen.coroutine
    def report(self):
        """ Listen to scheduler """
        while True:
            msg = yield self.report_queue.get()
            if msg['op'] == 'start':
                self._scheduler_start_event.set()
            if msg['op'] == 'close':
                break
            if msg['op'] == 'key-in-memory':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'finished'
                    self.futures[msg['key']]['event'].set()
            if msg['op'] == 'lost-data':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'lost'
                    self.futures[msg['key']]['event'].clear()
            if msg['op'] == 'task-erred':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'error'
                    self.futures[msg['key']]['exception'] = msg['exception']
                    self.futures[msg['key']]['traceback'] = msg['traceback']
                    self.futures[msg['key']]['event'].set()

    @gen.coroutine
    def _shutdown(self, fast=False):
        """ Send shutdown signal and wait until scheduler completes """
        self.loop.add_callback(self.report_queue.put_nowait,
                               {'op': 'close'})
        self.loop.add_callback(self.scheduler_queue.put_nowait,
                               {'op': 'close'})
        if self in _global_executors:
            _global_executors.remove(self)
        if not fast:
            yield self.coroutines

    def shutdown(self):
        """ Send shutdown signal and wait until scheduler terminates """
        self.loop.add_callback(self.report_queue.put_nowait, {'op': 'close'})
        self.loop.add_callback(self.scheduler_queue.put_nowait, {'op': 'close'})
        self.loop.stop()
        self._loop_thread.join()
        if self in _global_executors:
            _global_executors.remove(self)

    def submit(self, func, *args, **kwargs):
        """ Submit a function application to the scheduler

        Parameters
        ----------
        func: callable
        *args:
        **kwargs:
        pure: bool (defaults to True)
            Whether or not the function is pure.  Set ``pure=False`` for
            impure functions like ``np.random.random``.
        workers: set, iterable of sets
            A set of worker hostnames on which computations may be performed.
            Leave empty to default to all workers (common case)

        Examples
        --------
        >>> c = executor.submit(add, a, b)  # doctest: +SKIP

        Returns
        -------
        Future

        See Also
        --------
        distributed.executor.Executor.map:
        """
        if not callable(func):
            raise TypeError("First input to submit must be a callable function")

        key = kwargs.pop('key', None)
        pure = kwargs.pop('pure', True)
        workers = kwargs.pop('workers', None)

        if key is None:
            if pure:
                key = funcname(func) + '-' + tokenize(func, kwargs, *args)
            else:
                key = funcname(func) + '-' + next(tokens)

        if key in self.futures:
            return Future(key, self)

        if kwargs:
            task = (apply, func, args, kwargs)
        else:
            task = (func,) + args

        if workers is not None:
            restrictions = {key: workers}
        else:
            restrictions = {}

        logger.debug("Submit %s(...), %s", funcname(func), key)
        self.loop.add_callback(self.scheduler_queue.put_nowait,
                                        {'op': 'update-graph',
                                         'dsk': {key: task},
                                         'keys': [key],
                                         'restrictions': restrictions})

        return Future(key, self)

    def map(self, func, *iterables, **kwargs):
        """ Map a function on a sequence of arguments

        Arguments can be normal objects or Futures

        Parameters
        ----------
        func: callable
        iterables: Iterables
        pure: bool (defaults to True)
            Whether or not the function is pure.  Set ``pure=False`` for
            impure functions like ``np.random.random``.
        workers: set, iterable of sets
            A set of worker hostnames on which computations may be performed.
            Leave empty to default to all workers (common case)

        Examples
        --------
        >>> L = executor.map(func, sequence)  # doctest: +SKIP

        Returns
        -------
        list of futures

        See also
        --------
        distributed.executor.Executor.submit
        """
        pure = kwargs.pop('pure', True)
        workers = kwargs.pop('workers', None)
        if not callable(func):
            raise TypeError("First input to map must be a callable function")
        iterables = [list(it) for it in iterables]
        if pure:
            keys = [funcname(func) + '-' + tokenize(func, kwargs, *args)
                    for args in zip(*iterables)]
        else:
            uid = str(uuid.uuid4())
            keys = [funcname(func) + '-' + uid + '-' + next(tokens)
                    for i in range(min(map(len, iterables)))]

        if not kwargs:
            dsk = {key: (func,) + args
                   for key, args in zip(keys, zip(*iterables))}
        else:
            dsk = {key: (apply, func, args, kwargs)
                   for key, args in zip(keys, zip(*iterables))}

        if isinstance(workers, (list, set)):
            if workers and isinstance(first(workers), (list, set)):
                if len(workers) != len(keys):
                    raise ValueError("You only provided %d worker restrictions"
                    " for a sequence of length %d" % (len(workers), len(keys)))
                restrictions = dict(zip(keys, workers))
            else:
                restrictions = {key: workers for key in keys}
        elif workers is None:
            restrictions = {}
        else:
            raise TypeError("Workers must be a list or set of workers or None")

        logger.debug("map(%s, ...)", funcname(func))
        self.loop.add_callback(self.scheduler_queue.put_nowait,
                                        {'op': 'update-graph',
                                         'dsk': dsk,
                                         'keys': keys,
                                         'restrictions': restrictions})

        return [Future(key, self) for key in keys]

    @gen.coroutine
    def _gather(self, futures):
        futures2, keys = unpack_remotedata(futures)
        keys = list(keys)

        while True:
            logger.debug("Waiting on futures to clear before gather")
            yield All([self.futures[key]['event'].wait() for key in keys
                                                    if key in self.futures])
            exceptions = [self.futures[key]['exception'] for key in keys
                          if self.futures[key]['status'] == 'error']
            if exceptions:
                raise exceptions[0]
            try:
                data = yield _gather(self.center, keys)
            except KeyError as e:
                logger.debug("Couldn't gather keys %s", e)
                self.loop.add_callback(self.scheduler_queue.put_nowait,
                                                {'op': 'missing-data',
                                                 'missing': e.args})
                for key in e.args:
                    self.futures[key]['event'].clear()
            else:
                break

        data = dict(zip(keys, data))

        result = pack_data(futures2, data)
        raise gen.Return(result)

    def gather(self, futures):
        """ Gather futures from distributed memory

        Accepts a future or any nested core container of futures

        Examples
        --------
        >>> from operator import add  # doctest: +SKIP
        >>> e = Executor('127.0.0.1:8787')  # doctest: +SKIP
        >>> x = e.submit(add, 1, 2)  # doctest: +SKIP
        >>> e.gather(x)  # doctest: +SKIP
        3
        >>> e.gather([x, [x], x])  # doctest: +SKIP
        [3, [3], 3]
        """
        return sync(self.loop, self._gather, futures)

    @gen.coroutine
    def _scatter(self, data, workers=None):
        if not self.scheduler.ncores:
            raise ValueError("No workers yet found.  "
                             "Try syncing with center.\n"
                             "  e.sync_center()")
        ncores = workers if workers is not None else self.scheduler.ncores
        remotes, who_has, nbytes = yield scatter_to_workers(
                                            self.center, ncores, data)
        if isinstance(remotes, list):
            remotes = [Future(r.key, self) for r in remotes]
        elif isinstance(remotes, dict):
            remotes = {k: Future(v.key, self) for k, v in remotes.items()}
        self.loop.add_callback(self.scheduler_queue.put_nowait,
                                        {'op': 'update-data',
                                         'who-has': who_has,
                                         'nbytes': nbytes})
        while not all(k in self.scheduler.who_has for k in who_has):
            yield gen.sleep(0.001)
        raise gen.Return(remotes)

    def scatter(self, data, workers=None):
        """ Scatter data into distributed memory

        Accepts a list of data elements or dict of key-value pairs

        Optionally provide a set of workers to constrain the scatter.  Specify
        workers as hostname/port pairs, e.g. ('127.0.0.1', 8787).
        Default port is 8788.

        Examples
        --------
        >>> e = Executor('127.0.0.1:8787')  # doctest: +SKIP
        >>> e.scatter([1, 2, 3])  # doctest: +SKIP
        [RemoteData<center=127.0.0.1:8787, key=d1d26ff2-8...>,
         RemoteData<center=127.0.0.1:8787, key=d1d26ff2-8...>,
         RemoteData<center=127.0.0.1:8787, key=d1d26ff2-8...>]
        >>> e.scatter({'x': 1, 'y': 2, 'z': 3})  # doctest: +SKIP
        {'x': RemoteData<center=127.0.0.1:8787, key=x>,
         'y': RemoteData<center=127.0.0.1:8787, key=y>,
         'z': RemoteData<center=127.0.0.1:8787, key=z>}

        >>> e.scatter([1, 2, 3], workers=[('hostname', 8788)])  # doctest: +SKIP
        """
        return sync(self.loop, self._scatter, data, workers=workers)

    @gen.coroutine
    def _get(self, dsk, keys, restrictions=None, raise_on_error=True):
        flatkeys = list(flatten([keys]))
        futures = {key: Future(key, self) for key in flatkeys}

        self.loop.add_callback(self.scheduler_queue.put_nowait,
                                        {'op': 'update-graph',
                                         'dsk': dsk,
                                         'keys': flatkeys,
                                         'restrictions': restrictions or {}})

        packed = pack_data(keys, futures)
        if raise_on_error:
            result = yield self._gather(packed)
        else:
            try:
                result = yield self._gather(packed)
                result = 'OK', result
            except Exception as e:
                result = 'error', e
        raise gen.Return(result)

    def get(self, dsk, keys, **kwargs):
        """ Gather futures from distributed memory

        Parameters
        ----------
        dsk: dict
        keys: object, or nested lists of objects
        restrictions: dict (optional)
            A mapping of {key: {set of worker hostnames}} that restricts where
            jobs can take place

        Examples
        --------
        >>> from operator import add  # doctest: +SKIP
        >>> e = Executor('127.0.0.1:8787')  # doctest: +SKIP
        >>> e.get({'x': (add, 1, 2)}, 'x')  # doctest: +SKIP
        3
        """
        status, result = sync(self.loop, self._get, dsk, keys,
                              raise_on_error=False, **kwargs)

        if status == 'error':
            raise result
        else:
            return result

    @gen.coroutine
    def _restart(self):
        logger.debug("Sending shutdown signal to workers")
        nannies = yield self.center.nannies()
        for addr in nannies:
            self.loop.add_callback(self.scheduler_queue.put_nowait,
                    {'op': 'worker-failed', 'worker': addr, 'heal': False})

        logger.debug("Sending kill signal to nannies")
        nannies = [rpc(ip=ip, port=n_port)
                   for (ip, w_port), n_port in nannies.items()]
        yield All([nanny.kill() for nanny in nannies])

        while self.scheduler.ncores:
            yield gen.sleep(0.01)

        yield self._shutdown(fast=True)

        events = [d['event'] for d in self.futures.values()]
        self.futures.clear()
        for e in events:
            e.set()

        yield All([nanny.instantiate(close=True) for nanny in nannies])

        logger.info("Restarting executor")
        self.scheduler.report_queue = Queue()
        self.scheduler.scheduler_queue = Queue()
        self.scheduler.delete_queue = Queue()
        yield self._start()

    def restart(self):
        """ Restart the distributed network

        This kills all active work, deletes all data on the network, and
        restarts the worker processes.
        """
        return sync(self.loop, self._restart)

    @gen.coroutine
    def _upload_file(self, filename):
        with open(filename, 'rb') as f:
            data = f.read()
        _, fn = os.path.split(filename)
        d = yield self.center.broadcast(msg={'op': 'upload_file',
                                             'filename': fn,
                                             'data': data})

        assert all(len(data) == v for v in d.values())

    def upload_file(self, filename):
        """ Upload local package to workers

        Parameters
        ----------
        filename: string
            Filename of .py file to send to workers
        """
        return sync(self.loop, self._upload_file, filename)
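
A short end-to-end sketch of the Executor above, mirroring the doctest
examples in its docstrings (the center address is a placeholder and a
running center/worker setup is assumed):

from operator import add

executor = Executor(('127.0.0.1', 8787))  # connects and starts the IOLoop thread
a = executor.submit(add, 1, 2)            # returns a Future immediately
b = executor.submit(add, a, 10)           # futures can feed further submits
print(executor.gather(b))                 # blocks until the result arrives: 13
executor.shutdown()                       # stop the scheduler and IOLoop thread
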
Code example #20
File: pubnub_tornado.py Project: pubnub/python
class TornadoSubscriptionManager(SubscriptionManager):
    def __init__(self, pubnub_instance):

        subscription_manager = self

        self._message_queue = Queue()
        self._consumer_event = Event()
        self._cancellation_event = Event()
        self._subscription_lock = Semaphore(1)
        # self._current_request_key_object = None
        self._heartbeat_periodic_callback = None
        self._reconnection_manager = TornadoReconnectionManager(pubnub_instance)

        super(TornadoSubscriptionManager, self).__init__(pubnub_instance)
        self._start_worker()

        class TornadoReconnectionCallback(ReconnectionCallback):
            def on_reconnect(self):
                subscription_manager.reconnect()

                pn_status = PNStatus()
                pn_status.category = PNStatusCategory.PNReconnectedCategory
                pn_status.error = False

                subscription_manager._subscription_status_announced = True
                subscription_manager._listener_manager.announce_status(pn_status)

        self._reconnection_listener = TornadoReconnectionCallback()
        self._reconnection_manager.set_reconnection_listener(self._reconnection_listener)

    def _set_consumer_event(self):
        self._consumer_event.set()

    def _message_queue_put(self, message):
        self._message_queue.put(message)

    def _start_worker(self):
        self._consumer = TornadoSubscribeMessageWorker(self._pubnub,
                                                       self._listener_manager,
                                                       self._message_queue,
                                                       self._consumer_event)
        run = stack_context.wrap(self._consumer.run)
        self._pubnub.ioloop.spawn_callback(run)

    def reconnect(self):
        self._should_stop = False
        self._pubnub.ioloop.spawn_callback(self._start_subscribe_loop)
        # self._register_heartbeat_timer()

    def disconnect(self):
        self._should_stop = True
        self._stop_heartbeat_timer()
        self._stop_subscribe_loop()

    @tornado.gen.coroutine
    def _start_subscribe_loop(self):
        self._stop_subscribe_loop()

        yield self._subscription_lock.acquire()

        self._cancellation_event.clear()

        combined_channels = self._subscription_state.prepare_channel_list(True)
        combined_groups = self._subscription_state.prepare_channel_group_list(True)

        if len(combined_channels) == 0 and len(combined_groups) == 0:
            return

        envelope_future = Subscribe(self._pubnub) \
            .channels(combined_channels).channel_groups(combined_groups) \
            .timetoken(self._timetoken).region(self._region) \
            .filter_expression(self._pubnub.config.filter_expression) \
            .cancellation_event(self._cancellation_event) \
            .future()

        canceller_future = self._cancellation_event.wait()

        wi = tornado.gen.WaitIterator(envelope_future, canceller_future)

        # iterates twice: once for the result, once for the canceller
        while not wi.done():
            try:
                result = yield wi.next()
            except Exception as e:
                # TODO: verify the error will not be eaten
                logger.error(e)
                raise
            else:
                if wi.current_future == envelope_future:
                    e = result
                elif wi.current_future == canceller_future:
                    return
                else:
                    raise Exception("Unexpected future resolved: %s" % str(wi.current_future))

                if e.is_error():
                    # a 599 error doesn't tell us much - tornado uses this
                    # status code for a wide range of errors, for example:
                    # HTTP Server Error (599): [Errno -2] Name or service not known
                    if e.status is not None and e.status.category == PNStatusCategory.PNTimeoutCategory:
                        self._pubnub.ioloop.spawn_callback(self._start_subscribe_loop)
                        return

                    logger.error("Exception in subscribe loop: %s" % str(e))

                    if e.status is not None and e.status.category == PNStatusCategory.PNAccessDeniedCategory:
                        e.status.operation = PNOperationType.PNUnsubscribeOperation

                    self._listener_manager.announce_status(e.status)

                    self._reconnection_manager.start_polling()
                    self.disconnect()
                    return
                else:
                    self._handle_endpoint_call(e.result, e.status)

                    self._pubnub.ioloop.spawn_callback(self._start_subscribe_loop)

            finally:
                self._cancellation_event.set()
                yield tornado.gen.moment
                self._subscription_lock.release()
                self._cancellation_event.clear()
                break

    def _stop_subscribe_loop(self):
        if self._cancellation_event is not None and not self._cancellation_event.is_set():
            self._cancellation_event.set()

    def _stop_heartbeat_timer(self):
        if self._heartbeat_periodic_callback is not None:
            self._heartbeat_periodic_callback.stop()

    def _register_heartbeat_timer(self):
        super(TornadoSubscriptionManager, self)._register_heartbeat_timer()
        self._heartbeat_periodic_callback = PeriodicCallback(
            stack_context.wrap(self._perform_heartbeat_loop),
            self._pubnub.config.heartbeat_interval * TornadoSubscriptionManager.HEARTBEAT_INTERVAL_MULTIPLIER,
            self._pubnub.ioloop)
        self._heartbeat_periodic_callback.start()

    @tornado.gen.coroutine
    def _perform_heartbeat_loop(self):
        if self._heartbeat_call is not None:
            # TODO: cancel call
            pass

        cancellation_event = Event()
        state_payload = self._subscription_state.state_payload()
        presence_channels = self._subscription_state.prepare_channel_list(False)
        presence_groups = self._subscription_state.prepare_channel_group_list(False)

        if len(presence_channels) == 0 and len(presence_groups) == 0:
            return

        try:
            envelope = yield self._pubnub.heartbeat() \
                .channels(presence_channels) \
                .channel_groups(presence_groups) \
                .state(state_payload) \
                .cancellation_event(cancellation_event) \
                .future()

            heartbeat_verbosity = self._pubnub.config.heartbeat_notification_options
            if envelope.status.is_error():
                if heartbeat_verbosity == PNHeartbeatNotificationOptions.ALL or \
                        heartbeat_verbosity == PNHeartbeatNotificationOptions.FAILURES:
                    self._listener_manager.announce_status(envelope.status)
            else:
                if heartbeat_verbosity == PNHeartbeatNotificationOptions.ALL:
                    self._listener_manager.announce_status(envelope.status)

        except PubNubTornadoException:
            pass
            # TODO: check correctness
            # if e.status is not None and e.status.category == PNStatusCategory.PNTimeoutCategory:
            #     self._start_subscribe_loop()
            # else:
            #     self._listener_manager.announce_status(e.status)
        except Exception as e:
            print(e)
        finally:
            cancellation_event.set()

    @tornado.gen.coroutine
    def _send_leave(self, unsubscribe_operation):
        envelope = yield Leave(self._pubnub) \
            .channels(unsubscribe_operation.channels) \
            .channel_groups(unsubscribe_operation.channel_groups).future()
        self._listener_manager.announce_status(envelope.status)
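
The core of _start_subscribe_loop above is racing the long-poll future
against a cancellation Event via tornado.gen.WaitIterator: whichever future
resolves first decides what happens next. The same pattern in isolation
(a minimal sketch, not PubNub-specific):

import tornado.gen

@tornado.gen.coroutine
def race(work_future, cancellation_event):
    canceller = cancellation_event.wait()
    wi = tornado.gen.WaitIterator(work_future, canceller)
    while not wi.done():
        result = yield wi.next()
        if wi.current_future == work_future:
            raise tornado.gen.Return(result)  # work finished first
        else:
            raise tornado.gen.Return(None)    # cancelled before completion

Calling cancellation_event.set() from elsewhere resolves the canceller and
makes race() return None without waiting on work_future.
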
Code example #21
File: pubnub_tornado.py Project: pubnub/python
class TornadoReconnectionManager(ReconnectionManager):
    def __init__(self, pubnub):
        self._cancelled_event = Event()
        super(TornadoReconnectionManager, self).__init__(pubnub)

    @gen.coroutine
    def _register_heartbeat_timer(self):
        self._cancelled_event.clear()

        while not self._cancelled_event.is_set():
            if self._pubnub.config.reconnect_policy == PNReconnectionPolicy.EXPONENTIAL:
                self._timer_interval = int(math.pow(2, self._connection_errors) - 1)
                if self._timer_interval > self.MAXEXPONENTIALBACKOFF:
                    self._timer_interval = self.MINEXPONENTIALBACKOFF
                    self._connection_errors = 1
                    logger.debug("timerInterval > MAXEXPONENTIALBACKOFF at: %s" % utils.datetime_now())
                elif self._timer_interval < 1:
                    self._timer_interval = self.MINEXPONENTIALBACKOFF
                logger.debug("timerInterval = %d at: %s" % (self._timer_interval, utils.datetime_now()))
            else:
                self._timer_interval = self.INTERVAL

            # >>> Wait given interval or cancel
            sleeper = tornado.gen.sleep(self._timer_interval)
            canceller = self._cancelled_event.wait()

            wi = tornado.gen.WaitIterator(canceller, sleeper)

            while not wi.done():
                try:
                    future = wi.next()
                    yield future
                except Exception as e:
                    # TODO: verify the error will not be eaten
                    logger.error(e)
                    raise
                else:
                    if wi.current_future == sleeper:
                        break
                    elif wi.current_future == canceller:
                        return
                    else:
                        raise Exception("unknown future raised")

            logger.debug("reconnect loop at: %s" % utils.datetime_now())

            # >>> Attempt to request /time/0 endpoint
            try:
                yield self._pubnub.time().result()
                self._connection_errors = 1
                self._callback.on_reconnect()
                logger.debug("reconnection manager stop due success time endpoint call: %s" % utils.datetime_now())
                break
            except Exception:
                if self._pubnub.config.reconnect_policy == PNReconnectionPolicy.EXPONENTIAL:
                    logger.debug("reconnect interval increment at: %s" % utils.datetime_now())
                    self._connection_errors += 1

    def start_polling(self):
        if self._pubnub.config.reconnect_policy == PNReconnectionPolicy.NONE:
            logger.warn("reconnection policy is disabled, please handle reconnection manually.")
            return

        self._pubnub.ioloop.spawn_callback(self._register_heartbeat_timer)

    def stop_polling(self):
        if self._cancelled_event is not None and not self._cancelled_event.is_set():
            self._cancelled_event.set()
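
The EXPONENTIAL policy above computes its sleep interval as
2 ** connection_errors - 1 and falls back to the minimum once it overshoots
the cap. The schedule in isolation (the cap values here are illustrative;
the manager reads MINEXPONENTIALBACKOFF/MAXEXPONENTIALBACKOFF from its
class attributes):

import math

def backoff_interval(connection_errors, min_backoff=1, max_backoff=32):
    interval = int(math.pow(2, connection_errors) - 1)
    if interval > max_backoff or interval < 1:
        interval = min_backoff  # the manager also resets its error counter here
    return interval

# connection_errors 1..5 -> intervals of 1, 3, 7, 15, 31 seconds
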
Code example #22
class SubscribeListener(SubscribeCallback):
    def __init__(self):
        self.connected = False
        self.connected_event = Event()
        self.disconnected_event = Event()
        self.presence_queue = Queue()
        self.message_queue = Queue()
        self.error_queue = Queue()

    def status(self, pubnub, status):
        if utils.is_subscribed_event(status) and not self.connected_event.is_set():
            self.connected_event.set()
        elif utils.is_unsubscribed_event(status) and not self.disconnected_event.is_set():
            self.disconnected_event.set()
        elif status.is_error():
            self.error_queue.put_nowait(status.error_data.exception)

    def message(self, pubnub, message):
        self.message_queue.put(message)

    def presence(self, pubnub, presence):
        self.presence_queue.put(presence)

    @tornado.gen.coroutine
    def _wait_for(self, coro):
        error = self.error_queue.get()
        wi = tornado.gen.WaitIterator(coro, error)

        while not wi.done():
            result = yield wi.next()

            if wi.current_future == coro:
                raise gen.Return(result)
            elif wi.current_future == error:
                raise result
            else:
                raise Exception("Unexpected future resolved: %s" % str(wi.current_future))

    @tornado.gen.coroutine
    def wait_for_connect(self):
        if not self.connected_event.is_set():
            yield self._wait_for(self.connected_event.wait())
        else:
            raise Exception("instance is already connected")

    @tornado.gen.coroutine
    def wait_for_disconnect(self):
        if not self.disconnected_event.is_set():
            yield self._wait_for(self.disconnected_event.wait())
        else:
            raise Exception("instance is already disconnected")

    @tornado.gen.coroutine
    def wait_for_message_on(self, *channel_names):
        channel_names = list(channel_names)
        while True:
            try:  # NOQA
                env = yield self._wait_for(self.message_queue.get())
                if env.channel in channel_names:
                    raise tornado.gen.Return(env)
                else:
                    continue
            finally:
                self.message_queue.task_done()

    @tornado.gen.coroutine
    def wait_for_presence_on(self, *channel_names):
        channel_names = list(channel_names)
        while True:
            try:
                try:
                    env = yield self._wait_for(self.presence_queue.get())
                except:  # NOQA E722 pylint: disable=W0702
                    break
                if env.channel in channel_names:
                    raise tornado.gen.Return(env)
                else:
                    continue
            finally:
                self.presence_queue.task_done()
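
A typical way to drive the SubscribeListener above from a test coroutine
(the pubnub instance setup is assumed and the channel name is a
placeholder):

import tornado.gen

@tornado.gen.coroutine
def consume_one_message(pubnub):
    listener = SubscribeListener()
    pubnub.add_listener(listener)
    pubnub.subscribe().channels('my_channel').execute()
    yield listener.wait_for_connect()
    envelope = yield listener.wait_for_message_on('my_channel')
    raise tornado.gen.Return(envelope.message)
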
Code example #23
class SlaveHolder:
    def __init__(self, db, queue):
        self.db = db
        self.slaves = {}
        self._finished = Event()
        self._finished.set()
        self.queue = queue

    @coroutine
    def start(self):
        self._finished.clear()
        logging.debug('Starting slave-holder')

        cur = yield self.db.execute('SELECT * FROM registered_bots WHERE active = TRUE')
        columns = [i[0] for i in cur.description]

        while True:
            row = cur.fetchone()
            if not row:
                break

            row = dict(zip(columns, row))
            self._start_bot(**row)

        listen_future = self.queue.listen(slaveholder_queues(), self.queue_handler)

        try:
            yield self._finished.wait()
        finally:
            self.queue.stop(slaveholder_queues())
            yield listen_future

    def _start_bot(self, **kwargs):
        @coroutine
        def listen_done(f: Future):
            logging.debug('[bot#%s] Terminated', kwargs['id'])
            e = f.exception()
            if e:
                logging.debug('[bot#%s] Got exception: %s', kwargs['id'],
                              ''.join(format_exception(*f.exc_info())))
                if isinstance(e, ApiError) and e.code == 401:
                    logging.warning('[bot#%d] Disabling due to connection error', kwargs['id'])
                    yield self.queue.send(QUEUE_BOTERATOR_BOT_REVOKE, dumps(dict(error=str(e), **kwargs)))
                elif isinstance(e, ApiError) and e.code == 400 and 'chat not found' in e.description and \
                    str(kwargs['moderator_chat_id']) in e.request_body:
                    logging.warning('[bot#%d] Disabling due to unavailable moderator chat', kwargs['id'])
                    yield self.queue.send(QUEUE_BOTERATOR_BOT_REVOKE, dumps(dict(error=str(e), **kwargs)))
                elif isinstance(e, ApiError) and e.code == 409 and 'webhook is active' in e.description:
                    logging.warning('[bot#%d] Disabling due to misconfigured webhook', kwargs['id'])
                    yield self.queue.send(QUEUE_BOTERATOR_BOT_REVOKE, dumps(dict(error=str(e), **kwargs)))
                else:
                    IOLoop.current().add_timeout(timedelta(seconds=5), self._start_bot, **kwargs)

            del self.slaves[kwargs['id']]

        slave = Slave(db=self.db, **kwargs)
        slave_listen_f = slave.start()
        self.slaves[kwargs['id']] = {
            'future': slave_listen_f,
            'instance': slave,
        }
        IOLoop.current().add_future(slave_listen_f, listen_done)

    def stop(self):
        logging.info('Stopping slave-holder')
        for slave in self.slaves.values():
            slave['instance'].stop()

        self._finished.set()

    @coroutine
    def queue_handler(self, queue_name, body):
        body = loads(body.decode('utf-8'))

        if queue_name == QUEUE_SLAVEHOLDER_NEW_BOT:
            self._start_bot(**body)
        elif queue_name == QUEUE_SLAVEHOLDER_GET_BOT_INFO:
            bot = Api(body['token'], lambda x: None)

            if bot.bot_id in self.slaves:
                logging.debug('[bot#%s] Already registered', bot.bot_id)
                yield self.queue.send(body['reply_to'], dumps(dict(error='duplicate')))

            try:
                ret = yield bot.get_me()
                logging.debug('[bot#%s] Ok', bot.bot_id)
            except Exception as e:
                logging.debug('[bot#%s] Failed', bot.bot_id)
                yield self.queue.send(body['reply_to'], dumps(dict(error=str(e))))
                return

            yield self.queue.send(body['reply_to'], dumps(ret))
        elif queue_name == QUEUE_SLAVEHOLDER_GET_MODERATION_GROUP:
            update_with_command_f = Future()
            timeout_f = with_timeout(timedelta(seconds=body['timeout']), update_with_command_f)

            @coroutine
            def slave_update_handler(update):
                logging.debug('[bot#%s] Received update', bot.bot_id)
                if attach_cmd_filter.test(**update):
                    logging.debug('[bot#%s] /attach', bot.bot_id)
                    update_with_command_f.set_result(update)
                elif bot_added.test(**update):
                    logging.debug('[bot#%s] bot added to a group', bot.bot_id)
                    update_with_command_f.set_result(update)
                elif CommandFilterGroupChatCreated.test(**update) or CommandFilterSupergroupChatCreated.test(**update):
                    logging.debug('[bot#%s] group created', bot.bot_id)
                    update_with_command_f.set_result(update)
                else:
                    logging.debug('[bot#%s] unsupported update: %s',
                                  bot.bot_id, dumps(update, indent=2))

            bot = Api(body['token'], slave_update_handler)

            @coroutine
            def handle_finish(f):
                bot.stop()
                if not f.exception():
                    logging.debug('[bot#%s] Done', bot.bot_id)
                    update = f.result()
                    yield self.queue.send(body['reply_to'], dumps(dict(sender=update['message']['from'],
                                                                       **update['message']['chat'])))

                    # Mark last update as read
                    f2 = bot.get_updates(update['update_id'] + 1, timeout=0, retry_on_nonuser_error=True)
                    f2.add_done_callback(lambda x: x.exception())  # Ignore any exceptions
                else:
                    logging.debug('[bot#%s] Failed: %s', bot.bot_id, f.exception())

            timeout_f.add_done_callback(handle_finish)

            attach_cmd_filter = CommandFilterTextCmd('/attach')
            bot_added = CommandFilterNewChatMember(bot.bot_id)

            logging.debug('[bot#%s] Waiting for moderation group', bot.bot_id)
            bot.wait_commands()
        else:
            raise Exception('Unknown queue: %s' % queue_name)
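
The moderation-group branch above leans on tornado.gen.with_timeout to bound
how long it waits for an update to resolve update_with_command_f. The same
idea in isolation (a minimal sketch; the future would be resolved by some
other callback, as the update handler does above):

from datetime import timedelta

from tornado import gen
from tornado.concurrent import Future

@gen.coroutine
def wait_with_deadline(result_f: Future, seconds):
    timeout_f = gen.with_timeout(timedelta(seconds=seconds), result_f)
    try:
        result = yield timeout_f
    except gen.TimeoutError:
        result = None  # the deadline passed before the future resolved
    raise gen.Return(result)
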
Code example #24
class CommandHelper:
    def __init__(self, config):
        self.server = config.get_server()
        self.debug_enabled = config.getboolean('enable_repo_debug', False)
        if self.debug_enabled:
            logging.warn("UPDATE MANAGER: REPO DEBUG ENABLED")
        shell_command = self.server.lookup_plugin('shell_command')
        self.build_shell_command = shell_command.build_shell_command

        AsyncHTTPClient.configure(None, defaults=dict(user_agent="Moonraker"))
        self.http_client = AsyncHTTPClient()

        # GitHub API Rate Limit Tracking
        self.gh_rate_limit = None
        self.gh_limit_remaining = None
        self.gh_limit_reset_time = None
        self.gh_init_evt = Event()

        # Update In Progress Tracking
        self.cur_update_app = self.cur_update_id = None

    def get_server(self):
        return self.server

    def is_debug_enabled(self):
        return self.debug_enabled

    def set_update_info(self, app, uid):
        self.cur_update_app = app
        self.cur_update_id = uid

    def clear_update_info(self):
        self.cur_update_app = self.cur_update_id = None

    def is_app_updating(self, app_name):
        return self.cur_update_app == app_name

    def is_update_busy(self):
        return self.cur_update_app is not None

    def get_rate_limit_stats(self):
        return {
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
        }

    async def init_api_rate_limit(self):
        url = "https://api.github.com/rate_limit"
        while True:
            try:
                resp = await self.github_api_request(url, is_init=True)
                core = resp['resources']['core']
                self.gh_rate_limit = core['limit']
                self.gh_limit_remaining = core['remaining']
                self.gh_limit_reset_time = core['reset']
            except Exception:
                logging.exception("Error Initializing GitHub API Rate Limit")
                await tornado.gen.sleep(30.)
            else:
                reset_time = time.ctime(self.gh_limit_reset_time)
                logging.info(
                    "GitHub API Rate Limit Initialized\n"
                    f"Rate Limit: {self.gh_rate_limit}\n"
                    f"Rate Limit Remaining: {self.gh_limit_remaining}\n"
                    f"Rate Limit Reset Time: {reset_time}, "
                    f"Seconds Since Epoch: {self.gh_limit_reset_time}")
                break
        self.gh_init_evt.set()

    async def run_cmd(self,
                      cmd,
                      timeout=10.,
                      notify=False,
                      retries=1,
                      env=None):
        cb = self.notify_update_response if notify else None
        scmd = self.build_shell_command(cmd, callback=cb, env=env)
        while retries:
            if await scmd.run(timeout=timeout):
                break
            retries -= 1
        if not retries:
            raise self.server.error("Shell Command Error")

    async def run_cmd_with_response(self, cmd, timeout=10., env=None):
        scmd = self.build_shell_command(cmd, None, env=env)
        result = await scmd.run_with_response(timeout, retries=5)
        if result is None:
            raise self.server.error(f"Error Running Command: {cmd}")
        return result

    async def github_api_request(self, url, etag=None, is_init=False):
        if not is_init:
            timeout = time.time() + 30.
            try:
                await self.gh_init_evt.wait(timeout)
            except Exception:
                raise self.server.error("Timeout while waiting for GitHub "
                                        "API Rate Limit initialization")
        if self.gh_limit_remaining == 0:
            curtime = time.time()
            if curtime < self.gh_limit_reset_time:
                raise self.server.error(
                    f"GitHub Rate Limit Reached\nRequest: {url}\n"
                    f"Limit Reset Time: {time.ctime(self.gh_limit_remaining)}")
        headers = {"Accept": "application/vnd.github.v3+json"}
        if etag is not None:
            headers['If-None-Match'] = etag
        retries = 5
        while retries:
            try:
                timeout = time.time() + 10.
                fut = self.http_client.fetch(url,
                                             headers=headers,
                                             connect_timeout=5.,
                                             request_timeout=5.,
                                             raise_error=False)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                msg = f"Error Processing GitHub API request: {url}"
                if not retries:
                    raise self.server.error(msg)
                logging.exception(msg)
                await tornado.gen.sleep(1.)
                continue
            etag = resp.headers.get('etag', None)
            if etag is not None:
                if etag[:2] == "W/":
                    etag = etag[2:]
            logging.info("GitHub API Request Processed\n"
                         f"URL: {url}\n"
                         f"Response Code: {resp.code}\n"
                         f"Response Reason: {resp.reason}\n"
                         f"ETag: {etag}")
            if resp.code == 403:
                raise self.server.error(
                    f"Forbidden GitHub Request: {resp.reason}")
            elif resp.code == 304:
                logging.info(f"Github Request not Modified: {url}")
                return None
            if resp.code != 200:
                retries -= 1
                if not retries:
                    raise self.server.error(
                        f"Github Request failed: {resp.code} {resp.reason}")
                logging.info(
                    f"Github request error, {retries} retries remaining")
                await tornado.gen.sleep(1.)
                continue
            # Update the rate limit from headers on a successful response
            if 'X-Ratelimit-Limit' in resp.headers and not is_init:
                self.gh_rate_limit = int(resp.headers['X-Ratelimit-Limit'])
                self.gh_limit_remaining = int(
                    resp.headers['X-Ratelimit-Remaining'])
                self.gh_limit_reset_time = float(
                    resp.headers['X-Ratelimit-Reset'])
            decoded = json.loads(resp.body)
            decoded['etag'] = etag
            return decoded

    async def http_download_request(self, url):
        retries = 5
        while retries:
            try:
                timeout = time.time() + 130.
                fut = self.http_client.fetch(
                    url,
                    headers={"Accept": "application/zip"},
                    connect_timeout=5.,
                    request_timeout=120.)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                logging.exception("Error Processing Download")
                if not retries:
                    raise
                await tornado.gen.sleep(1.)
                continue
            return resp.body

    def notify_update_response(self, resp, is_complete=False):
        resp = resp.strip()
        if isinstance(resp, bytes):
            resp = resp.decode()
        notification = {
            'message': resp,
            'application': self.cur_update_app,
            'proc_id': self.cur_update_id,
            'complete': is_complete
        }
        self.server.send_event("update_manager:update_response", notification)

    def close(self):
        self.http_client.close()
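
A note on usage: the ETag handling above implements conditional requests.
Passing a previously returned ETag back through the `etag` argument sets
If-None-Match, and a 304 Not Modified reply makes github_api_request
return None, signalling that a cached payload is still current. A minimal
caller sketch (fetch_repo_info and _repo_cache are hypothetical names,
not part of the class above):

_repo_cache = {}

async def fetch_repo_info(updater, url):
    # `updater` is assumed to be an instance of the class above.
    resp = await updater.github_api_request(url, etag=_repo_cache.get('etag'))
    if resp is None:
        # 304 Not Modified: reuse the cached body.
        return _repo_cache['body']
    _repo_cache['etag'] = resp['etag']
    _repo_cache['body'] = resp
    return resp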
Code Example #25
File: websocket_test.py Project: leeclemens/tornado
class WebSocketTest(WebSocketBaseTestCase):
    def get_app(self):
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
            ('/header', HeaderHandler, dict(close_future=self.close_future)),
            ('/header_echo', HeaderEchoHandler,
             dict(close_future=self.close_future)),
            ('/close_reason', CloseReasonHandler,
             dict(close_future=self.close_future)),
            ('/error_in_on_message', ErrorInOnMessageHandler,
             dict(close_future=self.close_future)),
            ('/async_prepare', AsyncPrepareHandler,
             dict(close_future=self.close_future)),
            ('/path_args/(.*)', PathArgsHandler,
             dict(close_future=self.close_future)),
            ('/coroutine', CoroutineOnMessageHandler,
             dict(close_future=self.close_future)),
            ('/render', RenderMessageHandler,
             dict(close_future=self.close_future)),
            ('/subprotocol', SubprotocolHandler,
             dict(close_future=self.close_future)),
            ('/open_coroutine', OpenCoroutineHandler,
             dict(close_future=self.close_future, test=self)),
        ], template_loader=DictLoader({
            'message.html': '<b>{{ message }}</b>',
        }))

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        super(WebSocketTest, self).tearDown()
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch('/echo')
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        response = self.fetch('/echo',
                              headers={'Connection': 'Upgrade',
                                       'Upgrade': 'WebSocket',
                                       'Sec-WebSocket-Version': '13'})
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        response = self.fetch('/echo',
                              headers={'Connection': 'Upgrade',
                                       'Upgrade': 'WebSocket',
                                       'Sec-WebSocket-Version': '12'})
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self.ws_connect('/echo')
        yield ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        yield self.close(ws)

    def test_websocket_callbacks(self):
        websocket_connect(
            'ws://127.0.0.1:%d/echo' % self.get_http_port(),
            callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self.ws_connect('/echo')
        ws.write_message(b'hello \xe9', binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b'hello \xe9')
        yield self.close(ws)

    @gen_test
    def test_unicode_message(self):
        ws = yield self.ws_connect('/echo')
        ws.write_message(u'hello \u00e9')
        response = yield ws.read_message()
        self.assertEqual(response, u'hello \u00e9')
        yield self.close(ws)

    @gen_test
    def test_render_message(self):
        ws = yield self.ws_connect('/render')
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, '<b>hello</b>')
        yield self.close(ws)

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self.ws_connect('/error_in_on_message')
        ws.write_message('hello')
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)
        yield self.close(ws)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect('/notfound')
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect('/non_ws')

    @gen_test
    def test_websocket_network_fail(self):
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://127.0.0.1:%d/' % port,
                    connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        ws = yield websocket_connect(
            'ws://127.0.0.1:%d/echo' % self.get_http_port())
        ws.write_message('hello')
        ws.write_message('world')
        # Close the underlying stream.
        ws.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
                        headers={'X-Test': 'hello'}))
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        yield self.close(ws)

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        ws = yield websocket_connect(
            HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(),
                        headers={'X-Test-Hello': 'hello'}))
        self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
        self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value')
        yield self.close(ws)

    @gen_test
    def test_server_close_reason(self):
        ws = yield self.ws_connect('/close_reason')
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        ws = yield self.ws_connect('/echo')
        ws.close(1001, 'goodbye')
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, 'goodbye')

    @gen_test
    def test_write_after_close(self):
        ws = yield self.ws_connect('/close_reason')
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        with self.assertRaises(WebSocketClosedError):
            ws.write_message('hello')

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect('/async_prepare')
        ws.write_message('hello')
        res = yield ws.read_message()
        self.assertEqual(res, 'hello')

    @gen_test
    def test_path_args(self):
        ws = yield self.ws_connect('/path_args/hello')
        res = yield ws.read_message()
        self.assertEqual(res, 'hello')

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect('/coroutine')
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message('hello1')
        yield ws.write_message('hello2')
        res = yield ws.read_message()
        self.assertEqual(res, 'hello1')
        res = yield ws.read_message()
        self.assertEqual(res, 'hello2')

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()

        url = 'ws://127.0.0.1:%d/echo' % port
        headers = {'Origin': 'http://127.0.0.1:%d' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        yield self.close(ws)

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()

        url = 'ws://127.0.0.1:%d/echo' % port
        headers = {'Origin': 'http://127.0.0.1:%d/something' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        yield self.close(ws)

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()

        url = 'ws://127.0.0.1:%d/echo' % port
        headers = {'Origin': '127.0.0.1:%d' % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = 'ws://127.0.0.1:%d/echo' % port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        headers = {'Origin': 'http://somewhereelse.com'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {'Origin': 'http://subtenant.localhost'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto'])
        self.assertEqual(ws.selected_subprotocol, 'goodproto')
        res = yield ws.read_message()
        self.assertEqual(res, 'subprotocol=goodproto')
        yield self.close(ws)

    @gen_test
    def test_subprotocols_not_offered(self):
        ws = yield self.ws_connect('/subprotocol')
        self.assertIs(ws.selected_subprotocol, None)
        res = yield ws.read_message()
        self.assertEqual(res, 'subprotocol=None')
        yield self.close(ws)

    @gen_test
    def test_open_coroutine(self):
        self.message_sent = Event()
        ws = yield self.ws_connect('/open_coroutine')
        yield ws.write_message('hello')
        self.message_sent.set()
        res = yield ws.read_message()
        self.assertEqual(res, 'ok')
        yield self.close(ws)
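
The handlers these tests wire up (EchoHandler, HeaderHandler, and the
rest) are defined elsewhere in websocket_test.py. For orientation, here
is a minimal sketch of what the /echo endpoint would need to look like
for the tests above; this is a reconstruction under assumptions, not the
file's actual handler (which shares a common test base class):

from tornado.websocket import WebSocketHandler

class EchoHandler(WebSocketHandler):
    """Echo each message back, keeping binary frames binary."""
    def initialize(self, close_future=None):
        # Tests yield close_future to observe (close_code, close_reason).
        self.close_future = close_future

    def on_message(self, message):
        self.write_message(message, binary=isinstance(message, bytes))

    def on_close(self):
        if self.close_future is not None:
            self.close_future.set_result((self.close_code, self.close_reason))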
Code Example #26
class GitUpdater:
    def __init__(self, umgr, config, path=None, env=None):
        self.server = umgr.server
        self.execute_cmd = umgr.execute_cmd
        self.execute_cmd_with_response = umgr.execute_cmd_with_response
        self.notify_update_response = umgr.notify_update_response
        self.name = config.get_name().split()[-1]
        self.repo_path = path
        if path is None:
            self.repo_path = config.get('path')
        self.env = config.get("env", env)
        dist_packages = None
        self.python_reqs = None  # avoids AttributeError in the path check below
        if self.env is not None:
            self.env = os.path.expanduser(self.env)
            dist_packages = config.get('python_dist_packages', None)
            self.python_reqs = os.path.join(self.repo_path,
                                            config.get("requirements"))
        self.origin = config.get("origin").lower()
        self.install_script = config.get('install_script', None)
        if self.install_script is not None:
            self.install_script = os.path.abspath(
                os.path.join(self.repo_path, self.install_script))
        self.venv_args = config.get('venv_args', None)
        self.python_dist_packages = None
        self.python_dist_path = None
        self.env_package_path = None
        if dist_packages is not None:
            self.python_dist_packages = [
                p.strip() for p in dist_packages.split('\n') if p.strip()
            ]
            self.python_dist_path = os.path.abspath(
                config.get('python_dist_path'))
            if not os.path.exists(self.python_dist_path):
                raise config.error(
                    "Invalid path for option 'python_dist_path'")
            self.env_package_path = os.path.abspath(
                os.path.join(os.path.dirname(self.env), "..",
                             config.get('env_package_path')))
        for opt in [
                "repo_path", "env", "python_reqs", "install_script",
                "python_dist_path", "env_package_path"
        ]:
            val = getattr(self, opt)
            if val is None:
                continue
            if not os.path.exists(val):
                raise config.error("Invalid path for option '%s': %s" %
                                   (opt, val))

        self.version = self.cur_hash = "?"
        self.remote_version = self.remote_hash = "?"
        self.init_evt = Event()
        self.refresh_condition = None
        self.debug = umgr.repo_debug
        self.remote = "origin"
        self.branch = "master"
        self.is_valid = self.is_dirty = self.detached = False

    def _get_version_info(self):
        ver_path = os.path.join(self.repo_path, "scripts/version.txt")
        vinfo = {}
        if os.path.isfile(ver_path):
            data = ""
            with open(ver_path, 'r') as f:
                data = f.read()
            try:
                entries = [e.strip() for e in data.split('\n') if e.strip()]
                vinfo = dict([i.split('=') for i in entries])
                vinfo = {
                    # int tuples so comparisons are numeric, not lexicographic
                    k: tuple(int(x) for x in re.findall(r"\d+", v))
                    for k, v in vinfo.items()
                }
            except Exception:
                pass
            else:
                self._log_info(f"Version Info Found: {vinfo}")
        vinfo['version'] = tuple(
            int(x) for x in re.findall(r"\d+", self.version))
        return vinfo

    def _log_exc(self, msg, traceback=True):
        log_msg = f"Repo {self.name}: {msg}"
        if traceback:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return self.server.error(msg)

    def _log_info(self, msg):
        log_msg = f"Repo {self.name}: {msg}"
        logging.info(log_msg)

    def _notify_status(self, msg, is_complete=False):
        log_msg = f"Repo {self.name}: {msg}"
        logging.debug(log_msg)
        self.notify_update_response(log_msg, is_complete)

    async def check_initialized(self, timeout=None):
        if self.init_evt.is_set():
            return
        if timeout is not None:
            timeout = IOLoop.current().time() + timeout
        await self.init_evt.wait(timeout)

    async def refresh(self):
        if self.refresh_condition is None:
            self.refresh_condition = Condition()
        else:
            # Another refresh is in flight; wait for it to finish.
            await self.refresh_condition.wait()
            return
        try:
            await self._check_version()
        except Exception:
            logging.exception("Error Refreshing git state")
        self.init_evt.set()
        self.refresh_condition.notify_all()
        self.refresh_condition = None

    async def _check_version(self, need_fetch=True):
        self.is_valid = self.detached = False
        self.cur_hash = self.branch = self.remote = "?"
        self.version = self.remote_version = "?"
        try:
            blist = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} branch --list")
            if blist.startswith("fatal:"):
                self._log_info(f"Invalid git repo at path '{self.repo_path}'")
                return
            branch = None
            for b in blist.split("\n"):
                b = b.strip()
                if b.startswith("*"):
                    branch = b[2:]
                    break
            if branch is None:
                self._log_info(
                    "Unable to retreive current branch from branch list\n"
                    f"{blist}")
                return
            if "HEAD detached" in branch:
                bparts = branch.split()[-1].strip("()")
                self.remote, self.branch = bparts.split("/")
                self.detached = True
            else:
                self.branch = branch.strip()
                self.remote = await self.execute_cmd_with_response(
                    f"git -C {self.repo_path} config --get"
                    f" branch.{self.branch}.remote")
            if need_fetch:
                await self.execute_cmd(
                    f"git -C {self.repo_path} fetch {self.remote} --prune -q",
                    retries=3)
            remote_url = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} remote get-url {self.remote}")
            cur_hash = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} rev-parse HEAD")
            remote_hash = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} rev-parse "
                f"{self.remote}/{self.branch}")
            repo_version = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} describe --always "
                "--tags --long --dirty")
            remote_version = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} describe {self.remote}/{self.branch}"
                " --always --tags --long")
        except Exception:
            self._log_exc("Error retreiving git info")
            return

        self.is_dirty = repo_version.endswith("dirty")
        versions = []
        for ver in [repo_version, remote_version]:
            tag_version = "?"
            ver_match = re.match(r"v\d+\.\d+\.\d+-\d+", ver)
            if ver_match:
                tag_version = ver_match.group()
            versions.append(tag_version)
        self.version, self.remote_version = versions
        self.cur_hash = cur_hash.strip()
        self.remote_hash = remote_hash.strip()
        self._log_info(
            f"Repo Detected:\nPath: {self.repo_path}\nRemote: {self.remote}\n"
            f"Branch: {self.branch}\nRemote URL: {remote_url}\n"
            f"Current SHA: {self.cur_hash}\n"
            f"Remote SHA: {self.remote_hash}\nVersion: {self.version}\n"
            f"Remote Version: {self.remote_version}\n"
            f"Is Dirty: {self.is_dirty}\nIs Detached: {self.detached}")
        if self.debug:
            self.is_valid = True
            self._log_info("Debug enabled, bypassing official repo check")
        elif self.branch == "master" and self.remote == "origin":
            if self.detached:
                self._log_info("Detached HEAD detected, repo invalid")
                return
            remote_url = remote_url.lower()
            if remote_url[-4:] != ".git":
                remote_url += ".git"
            if remote_url == self.origin:
                self.is_valid = True
                self._log_info("Validity check for git repo passed")
            else:
                self._log_info(f"Invalid git origin url '{remote_url}'")
        else:
            self._log_info("Git repo not on offical remote/branch: "
                           f"{self.remote}/{self.branch}")

    async def update(self, update_deps=False):
        await self.check_initialized(20.)
        if self.refresh_condition is not None:
            # A refresh is in progress; wait for it before updating.
            await self.refresh_condition.wait()
        if not self.is_valid:
            raise self._log_exc("Update aborted, repo is not valid", False)
        if self.is_dirty:
            raise self._log_exc("Update aborted, repo is has been modified",
                                False)
        if self.remote_hash == self.cur_hash:
            # No need to update
            return
        self._notify_status("Updating Repo...")
        try:
            if self.detached:
                await self.execute_cmd(
                    f"git -C {self.repo_path} fetch {self.remote} -q",
                    retries=3)
                await self.execute_cmd(f"git -C {self.repo_path} checkout"
                                       f" {self.remote}/{self.branch} -q")
            else:
                await self.execute_cmd(f"git -C {self.repo_path} pull -q",
                                       retries=3)
        except Exception:
            raise self._log_exc("Error running 'git pull'")
        # Check Semantic Versions
        vinfo = self._get_version_info()
        cur_version = vinfo.get('version', ())
        update_deps |= cur_version < vinfo.get('deps_version', ())
        need_env_rebuild = cur_version < vinfo.get('env_version', ())
        if update_deps:
            await self._install_packages()
            await self._update_virtualenv(need_env_rebuild)
        elif need_env_rebuild:
            await self._update_virtualenv(True)
        # Refresh local repo state
        await self._check_version(need_fetch=False)
        if self.name == "moonraker":
            # Launch restart async so the request can return
            # before the server restarts
            self._notify_status("Update Finished...", is_complete=True)
            IOLoop.current().call_later(.1, self.restart_service)
        else:
            await self.restart_service()
            self._notify_status("Update Finished...", is_complete=True)

    async def _install_packages(self):
        if self.install_script is None:
            return
        # Open the install script and read it
        inst_path = self.install_script
        if not os.path.isfile(inst_path):
            self._log_info(f"Unable to open install script: {inst_path}")
            return
        with open(inst_path, 'r') as f:
            data = f.read()
        packages = re.findall(r'PKGLIST="(.*)"', data)
        # Strip the recursive "${PKGLIST}" token; lstrip() with a character
        # set could also eat leading letters of a package name.
        packages = [p.replace("${PKGLIST}", "").strip() for p in packages]
        if not packages:
            self._log_info(f"No packages found in script: {inst_path}")
            return
        # TODO: Log and notify that packages will be installed
        pkgs = " ".join(packages)
        logging.debug(f"Repo {self.name}: Detected Packages: {pkgs}")
        self._notify_status("Installing system dependencies...")
        # Install packages with apt-get
        try:
            await self.execute_cmd(f"{APT_CMD} update",
                                   timeout=300.,
                                   notify=True)
            await self.execute_cmd(f"{APT_CMD} install --yes {pkgs}",
                                   timeout=3600.,
                                   notify=True)
        except Exception:
            self._log_exc("Error updating packages via apt-get")
            return

    async def _update_virtualenv(self, rebuild_env=False):
        if self.env is None:
            return
        # Update python dependencies
        bin_dir = os.path.dirname(self.env)
        env_path = os.path.normpath(os.path.join(bin_dir, ".."))
        if rebuild_env:
            self._notify_status(f"Creating virtualenv at: {env_path}...")
            if os.path.exists(env_path):
                shutil.rmtree(env_path)
            try:
                await self.execute_cmd(
                    f"virtualenv {self.venv_args} {env_path}", timeout=300.)
            except Exception:
                self._log_exc(f"Error creating virtualenv")
                return
            if not os.path.exists(self.env):
                raise self._log_exc("Failed to create new virtualenv", False)
        reqs = self.python_reqs
        if not os.path.isfile(reqs):
            self._log_exc(f"Invalid path to requirements_file '{reqs}'")
            return
        pip = os.path.join(bin_dir, "pip")
        self._notify_status("Updating python packages...")
        try:
            await self.execute_cmd(f"{pip} install -r {reqs}",
                                   timeout=1200.,
                                   notify=True,
                                   retries=3)
        except Exception:
            self._log_exc("Error updating python requirements")
        self._install_python_dist_requirements()

    def _install_python_dist_requirements(self):
        dist_reqs = self.python_dist_packages
        if dist_reqs is None:
            return
        dist_path = self.python_dist_path
        site_path = self.env_package_path
        for pkg in dist_reqs:
            for f in os.listdir(dist_path):
                if f.startswith(pkg):
                    src = os.path.join(dist_path, f)
                    dest = os.path.join(site_path, f)
                    self._notify_status(f"Linking to dist package: {pkg}")
                    if os.path.islink(dest):
                        os.remove(dest)
                    elif os.path.exists(dest):
                        self._notify_status(
                            f"Error symlinking dist package: {pkg}, "
                            f"file already exists: {dest}")
                        continue
                    os.symlink(src, dest)
                    break

    async def restart_service(self):
        self._notify_status("Restarting Service...")
        try:
            await self.execute_cmd(f"sudo systemctl restart {self.name}")
        except Exception:
            raise self._log_exc("Error restarting service")

    def get_update_status(self):
        return {
            'remote_alias': self.remote,
            'branch': self.branch,
            'version': self.version,
            'remote_version': self.remote_version,
            'current_hash': self.cur_hash,
            'remote_hash': self.remote_hash,
            'is_dirty': self.is_dirty,
            'is_valid': self.is_valid,
            'detached': self.detached,
            'debug_enabled': self.debug
        }
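
The dependency checks in update() lean on Python's element-wise tuple
comparison over the numeric components that _get_version_info extracts.
A standalone illustration (the helper name and version strings are
hypothetical):

import re

def version_tuple(ver):
    # "v0.9.1-45-gabcdef-dirty" -> (0, 9, 1, 45)
    return tuple(int(p) for p in re.findall(r"\d+", ver))

cur_version = version_tuple("v0.9.1-45")
deps_version = version_tuple("v0.9.2-0")
# (0, 9, 1, 45) < (0, 9, 2, 0), so dependencies would be reinstalled.
update_deps = cur_version < deps_version  # True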
Code Example #27
File: queues.py Project: FlorianLudwig/tornado
class Queue(object):
    """Coordinate producer and consumer coroutines.

    If maxsize is 0 (the default) the queue size is unbounded.

    .. testcode::

        from tornado import gen
        from tornado.ioloop import IOLoop
        from tornado.queues import Queue

        q = Queue(maxsize=2)

        @gen.coroutine
        def consumer():
            while True:
                item = yield q.get()
                try:
                    print('Doing work on %s' % item)
                    yield gen.sleep(0.01)
                finally:
                    q.task_done()

        @gen.coroutine
        def producer():
            for item in range(5):
                yield q.put(item)
                print('Put %s' % item)

        @gen.coroutine
        def main():
            # Start consumer without waiting (since it never finishes).
            IOLoop.current().spawn_callback(consumer)
            yield producer()     # Wait for producer to put all tasks.
            yield q.join()       # Wait for consumer to finish all tasks.
            print('Done')

        IOLoop.current().run_sync(main)

    .. testoutput::

        Put 0
        Put 1
        Doing work on 0
        Put 2
        Doing work on 1
        Put 3
        Doing work on 2
        Put 4
        Doing work on 3
        Doing work on 4
        Done

    In Python 3.5, `Queue` implements the async iterator protocol, so
    ``consumer()`` could be rewritten as::

        async def consumer():
            async for item in q:
                try:
                    print('Doing work on %s' % item)
                    await gen.sleep(0.01)
                finally:
                    q.task_done()

    .. versionchanged:: 4.3
       Added ``async for`` support in Python 3.5.

    """
    def __init__(self, maxsize=0):
        if maxsize is None:
            raise TypeError("maxsize can't be None")

        if maxsize < 0:
            raise ValueError("maxsize can't be negative")

        self._maxsize = maxsize
        self._init()
        self._getters = collections.deque([])  # Futures.
        self._putters = collections.deque([])  # Pairs of (item, Future).
        self._unfinished_tasks = 0
        self._finished = Event()
        self._finished.set()

    @property
    def maxsize(self):
        """Number of items allowed in the queue."""
        return self._maxsize

    def qsize(self):
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self):
        """Return True if the queue is empty, False otherwise."""
        return not self._queue

    def full(self):
        """Return True if there are `maxsize` items in the queue.

        Never True if `maxsize` is 0 (the default).
        """
        if self.maxsize == 0:
            return False
        else:
            return self.qsize() >= self.maxsize

    def put(self, item, timeout=None):
        """Put an item into the queue, perhaps waiting until there is room.

        Returns a Future, which raises `tornado.util.TimeoutError` after a
        timeout.

        ``timeout`` may be a number denoting a time (on the same
        scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
        `datetime.timedelta` object for a deadline relative to the
        current time.
        """
        try:
            self.put_nowait(item)
        except QueueFull:
            future = Future()
            self._putters.append((item, future))
            _set_timeout(future, timeout)
            return future
        else:
            return gen._null_future

    def put_nowait(self, item):
        """Put an item into the queue without blocking.

        If no free slot is immediately available, raise `QueueFull`.
        """
        self._consume_expired()
        if self._getters:
            assert self.empty(), "queue non-empty, why are getters waiting?"
            getter = self._getters.popleft()
            self.__put_internal(item)
            getter.set_result(self._get())
        elif self.full():
            raise QueueFull
        else:
            self.__put_internal(item)

    def get(self, timeout=None):
        """Remove and return an item from the queue.

        Returns a Future which resolves once an item is available, or raises
        `tornado.util.TimeoutError` after a timeout.

        ``timeout`` may be a number denoting a time (on the same
        scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
        `datetime.timedelta` object for a deadline relative to the
        current time.
        """
        future = Future()
        try:
            future.set_result(self.get_nowait())
        except QueueEmpty:
            self._getters.append(future)
            _set_timeout(future, timeout)
        return future

    def get_nowait(self):
        """Remove and return an item from the queue without blocking.

        Return an item if one is immediately available, else raise
        `QueueEmpty`.
        """
        self._consume_expired()
        if self._putters:
            assert self.full(), "queue not full, why are putters waiting?"
            item, putter = self._putters.popleft()
            self.__put_internal(item)
            putter.set_result(None)
            return self._get()
        elif self.qsize():
            return self._get()
        else:
            raise QueueEmpty

    def task_done(self):
        """Indicate that a formerly enqueued task is complete.

        Used by queue consumers. For each `.get` used to fetch a task, a
        subsequent call to `.task_done` tells the queue that the processing
        on the task is complete.

        If a `.join` is blocking, it resumes when all items have been
        processed; that is, when every `.put` is matched by a `.task_done`.

        Raises `ValueError` if called more times than `.put`.
        """
        if self._unfinished_tasks <= 0:
            raise ValueError('task_done() called too many times')
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    def join(self, timeout=None):
        """Block until all items in the queue are processed.

        Returns a Future, which raises `tornado.util.TimeoutError` after a
        timeout.
        """
        return self._finished.wait(timeout)

    def __aiter__(self):
        # Since Python 3.5.2, __aiter__ must return the async iterator
        # itself, not an awaitable.
        return _QueueIterator(self)

    # These three are overridable in subclasses.
    def _init(self):
        self._queue = collections.deque()

    def _get(self):
        return self._queue.popleft()

    def _put(self, item):
        self._queue.append(item)
    # End of the overridable methods.

    def __put_internal(self, item):
        self._unfinished_tasks += 1
        self._finished.clear()
        self._put(item)

    def _consume_expired(self):
        # Remove timed-out waiters.
        while self._putters and self._putters[0][1].done():
            self._putters.popleft()

        while self._getters and self._getters[0].done():
            self._getters.popleft()

    def __repr__(self):
        return '<%s at %s %s>' % (
            type(self).__name__, hex(id(self)), self._format())

    def __str__(self):
        return '<%s %s>' % (type(self).__name__, self._format())

    def _format(self):
        result = 'maxsize=%r' % (self.maxsize, )
        if getattr(self, '_queue', None):
            result += ' queue=%r' % self._queue
        if self._getters:
            result += ' getters[%s]' % len(self._getters)
        if self._putters:
            result += ' putters[%s]' % len(self._putters)
        if self._unfinished_tasks:
            result += ' tasks=%s' % self._unfinished_tasks
        return result
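
The put/get timeouts documented above accept either an absolute deadline
on the IOLoop's time scale or a datetime.timedelta relative to now. A
short consumer sketch (hypothetical code, assuming the Queue class
above):

from datetime import timedelta

from tornado import gen
from tornado.util import TimeoutError as TornadoTimeoutError

@gen.coroutine
def consume_with_deadline(q):
    try:
        # A timedelta is interpreted relative to the current time.
        item = yield q.get(timeout=timedelta(seconds=1))
        print('Got %s' % item)
    except TornadoTimeoutError:
        print('No item arrived within 1 second')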
Code Example #28
class XEngineOperations(object):

    def __init__(self, corr_obj):
        """
        A collection of x-engine operations that act on/with a correlator
        instance.
        :param corr_obj: the FxCorrelator instance
        :return:
        """
        self.corr = corr_obj
        self.hosts = corr_obj.xhosts
        self.logger = corr_obj.logger
        self.data_stream = None

        self.vacc_synch_running = IOLoopEvent()
        self.vacc_synch_running.clear()
        self.vacc_check_enabled = IOLoopEvent()
        self.vacc_check_enabled.clear()
        self.vacc_check_cb = None
        self.vacc_check_cb_data = None

    @staticmethod
    def _gberst(hosts, state):
        THREADED_FPGA_OP(
            hosts, timeout=5,
            target_function=(
                lambda fpga_:
                fpga_.registers.control.write(gbe_rst=state),))

    def initialise_post_gbe(self):
        """
        Perform post-gbe setup initialisation steps
        :return:
        """
        # write the board IDs to the xhosts
        board_id = 0
        for f in self.hosts:
            f.registers.board_id.write(reg=board_id)
            board_id += 1

        # write the data stream destination to the registers
        self.write_data_stream_destination(None)

        # clear gbe status
        THREADED_FPGA_OP(
            self.hosts, timeout=5,
            target_function=(
                lambda fpga_:
                fpga_.registers.control.write(gbe_debug_rst='pulse'),))

        # release cores from reset
        XEngineOperations._gberst(self.hosts, False)

        # simulator
        if use_xeng_sim:
            THREADED_FPGA_OP(
                self.hosts, timeout=5,
                target_function=(
                    lambda fpga_: fpga_.registers.simulator.write(en=True),))

        # set up accumulation length
        self.set_acc_len(vacc_resync=False)

        # clear general status
        THREADED_FPGA_OP(
            self.hosts, timeout=5,
            target_function=(
                lambda fpga_:
                fpga_.registers.control.write(status_clr='pulse'),))

        # check for errors
        # TODO - read status regs?

    def initialise_pre_gbe(self):
        """
        Set up x-engines on this device.
        :return:
        """
        # simulator
        if use_xeng_sim:
            THREADED_FPGA_OP(
                self.hosts, timeout=5,
                target_function=(
                    lambda fpga_: fpga_.registers.simulator.write(
                        en=False, rst='pulse'),))

        # set the gapsize register
        gapsize = int(self.corr.configd['xengine']['10gbe_pkt_gapsize'])
        self.logger.info('X-engines: setting packet gap size to %i' % gapsize)
        if 'gapsize' in self.hosts[0].registers.names():
            # these versions have the correct logic surrounding the register
            THREADED_FPGA_OP(
                self.hosts, timeout=5,
                target_function=(
                    lambda fpga_: fpga_.registers.gapsize.write_int(gapsize),))
        elif 'gap_size' in self.hosts[0].registers.names():
            # these versions do not, they need a software hack for the setting
            # to 'take'
            THREADED_FPGA_OP(
                self.hosts, timeout=5,
                target_function=(
                    lambda fpga_: fpga_.registers.gap_size.write_int(gapsize),))
            # HACK - this is a hack to overcome broken x-engine firmware in
            # versions around a2d0615bc9cd95eabf7c8ed922c1a15658c0688e.
            # The logic next to the gap_size register is broken, registering
            # the LAST value written, not the new one.
            THREADED_FPGA_OP(
                self.hosts, timeout=5,
                target_function=(
                    lambda fpga_: fpga_.registers.gap_size.write_int(
                        gapsize-1),))
            # /HACK
        else:
            _errmsg = 'X-engine image has no register gap_size/gapsize?'
            self.logger.exception(_errmsg)
            raise RuntimeError(_errmsg)

        # disable transmission, place cores in reset, and give control
        # register a known state
        self.xeng_tx_disable(None)

        XEngineOperations._gberst(self.hosts, True)

        self.clear_status_all()

    def configure(self):
        """
        Configure the xengine operations - this is done whenever a correlator
        is instantiated.
        :return:
        """
        # set up the xengine data stream
        self._setup_data_stream()

    def _setup_data_stream(self):
        """
        Set up the data stream for the xengine output
        :return:
        """
        # the x-engine output data stream setup
        _xeng_d = self.corr.configd['xengine']

        data_addr = NetAddress(_xeng_d['output_destination_ip'],
                               _xeng_d['output_destination_port'])
        meta_addr = NetAddress(_xeng_d['output_destination_ip'],
                               _xeng_d['output_destination_port'])

        xeng_stream = data_stream.DataStream(
            name=_xeng_d['output_products'][0],
            category=data_stream.XENGINE_CROSS_PRODUCTS,
            destination=data_addr,
            meta_destination=meta_addr,
            destination_cb=self.write_data_stream_destination,
            meta_destination_cb=self.spead_meta_issue_all,
            tx_enable_method=self.xeng_tx_enable,
            tx_disable_method=self.xeng_tx_disable)

        self.data_stream = xeng_stream
        self.corr.register_data_stream(xeng_stream)
        self.vacc_check_enabled.clear()
        self.vacc_synch_running.clear()
        if self.vacc_check_cb is not None:
            self.vacc_check_cb.stop()
        self.vacc_check_cb = None

    def _vacc_periodic_check(self):

        self.logger.debug('Checking vacc operation @ %s' % time.ctime())

        if not self.vacc_check_enabled.is_set():
            self.logger.info('Check logic disabled, exiting')
            return

        if self.vacc_synch_running.is_set():
            self.logger.info('vacc_sync is currently running, exiting')
            return

        def get_data():
            """
            Get the relevant data from the X-engine FPGAs
            """
            # older versions had other register names
            _OLD = 'reorderr_timeout0' in self.hosts[0].registers.names()

            def _get_reorder_data(fpga):
                rv = {}
                for _ctr in range(0, fpga.x_per_fpga):
                    if _OLD:
                        _reg = fpga.registers['reorderr_timeout%i' % _ctr]
                        rv['etim%i' % _ctr] = _reg.read()['data']['reg']
                    else:
                        _reg = fpga.registers['reorderr_timedisc%i' % _ctr]
                        rv['etim%i' % _ctr] = _reg.read()['data']['timeout']
                return rv
            reo_data = THREADED_FPGA_OP(self.hosts, timeout=5,
                                        target_function=_get_reorder_data)
            vacc_data = self.vacc_status()
            return {'reorder': reo_data, 'vacc': vacc_data}
    
        def _vacc_data_check(d0, d1):
            # check errors are not incrementing
            for host in self.hosts:
                for xeng in range(0, host.x_per_fpga):
                    status0 = d0[host.host][xeng]
                    status1 = d1[host.host][xeng]
                    if ((status1['errors'] > status0['errors']) or
                            (status0['errors'] != 0)):
                        self.logger.error('    vacc %i on %s has '
                                          'errors' % (xeng, host.host))
                        return False
            # check that the accumulations are ticking over
            for host in self.hosts:
                for xeng in range(0, host.x_per_fpga):
                    status0 = d0[host.host][xeng]
                    status1 = d1[host.host][xeng]
                    if status1['count'] == status0['count']:
                        self.logger.error('    vacc %i on %s is not '
                                          'incrementing' % (xeng, host.host))
                        return False
            return True
    
        def _reorder_data_check(d0, d1):
            for host in self.hosts:
                for ctr in range(0, host.x_per_fpga):
                    reg = 'etim%i' % ctr
                    if d0[host.host][reg] != d1[host.host][reg]:
                        self.logger.error('    %s - vacc check reorder '
                                          'reg %s error' % (host.host, reg))
                        return False
            return True
    
        new_data = get_data()
        # self.logger.info('new_data: %s' % new_data)
    
        if self.vacc_check_cb_data is not None:
            force_sync = False
            # check the vacc status data first
            if not _vacc_data_check(self.vacc_check_cb_data['vacc'],
                                    new_data['vacc']):
                force_sync = True
            # check the reorder data
            if not force_sync:
                if not _reorder_data_check(self.vacc_check_cb_data['reorder'],
                                           new_data['reorder']):
                    force_sync = True
            if force_sync:
                self.logger.error('    forcing vacc sync')
                self.vacc_sync()

        self.corr.logger.debug('scheduled check done @ %s' % time.ctime())
        self.vacc_check_cb_data = new_data

    def vacc_check_timer_stop(self):
        """
        Disable the vacc_check timer
        :return:
        """
        if self.vacc_check_cb is not None:
            self.vacc_check_cb.stop()
        self.vacc_check_cb = None
        self.vacc_check_enabled.clear()
        self.corr.logger.info('vacc check timer stopped')

    def vacc_check_timer_start(self, vacc_check_time=30):
        """
        Set up a periodic check on the vacc operation.
        :param vacc_check_time: the interval, in seconds, at which to check
        :return:
        """
        if not IOLoop.current()._running:
            raise RuntimeError('IOLoop not running, this will not work')
        self.logger.info('xeng_setup_vacc_check_timer: setting up the '
                         'vacc check timer at %i seconds' % vacc_check_time)
        if vacc_check_time < self.get_acc_time():
            raise RuntimeError('A check time smaller than the accumulation '
                               'time makes no sense.')
        if self.vacc_check_cb is not None:
            self.vacc_check_cb.stop()
        self.vacc_check_cb = PeriodicCallback(self._vacc_periodic_check,
                                              vacc_check_time * 1000)
        self.vacc_check_enabled.set()
        self.vacc_check_cb.start()
        self.corr.logger.info('vacc check timer started')

    def write_data_stream_destination(self, data_stream):
        """
        Write the x-engine data stream destination to the hosts.
        :param data_stream - the data stream on which to act
        :return:
        """
        dstrm = data_stream or self.data_stream
        txip = int(dstrm.destination.ip)
        txport = dstrm.destination.port
        try:
            THREADED_FPGA_OP(
                self.hosts, timeout=10,
                target_function=(lambda fpga_:
                                 fpga_.registers.gbe_iptx.write(reg=txip),))
            THREADED_FPGA_OP(
                self.hosts, timeout=10,
                target_function=(lambda fpga_:
                                 fpga_.registers.gbe_porttx.write(reg=txport),))
        except AttributeError:
            self.logger.warning('Writing stream %s destination to '
                                'hardware failed!' % dstrm.name)

        # update meta data on stream destination change
        self.spead_meta_update_stream_destination()
        dstrm.meta_transmit()

        self.logger.info('Wrote stream %s destination to %s in hardware' % (
            dstrm.name, dstrm.destination))

    def clear_status_all(self):
        """
        Clear the various status registers and counters on all the fengines
        :return:
        """
        THREADED_FPGA_FUNC(self.hosts, timeout=10,
                           target_function='clear_status')

    def subscribe_to_multicast(self):
        """
        Subscribe the x-engines to the f-engine output multicast groups -
        each one subscribes to only one group, with data meant only for it.
        :return:
        """
        if self.corr.fengine_output.is_multicast():
            self.logger.info('F > X is multicast from base %s' %
                             self.corr.fengine_output)
            source_address = str(self.corr.fengine_output.ip_address)
            source_bits = source_address.split('.')
            source_base = int(source_bits[3])
            source_prefix = '%s.%s.%s.' % (source_bits[0],
                                           source_bits[1],
                                           source_bits[2])
            source_ctr = 0
            for host_ctr, host in enumerate(self.hosts):
                for gbe in host.tengbes:
                    rxaddress = '%s%d' % (source_prefix,
                                          source_base + source_ctr)
                    gbe.multicast_receive(rxaddress, 0)

                    # KLUDGE
                    source_ctr += 1
                    # source_ctr += 4

                    self.logger.info('\txhost %s %s subscribing to address %s' %
                                     (host.host, gbe.name, rxaddress))
        else:
            self.logger.info('F > X is unicast from base %s' %
                             self.corr.fengine_output)

    def check_rx(self, max_waittime=30):
        """
        Check that the x hosts are receiving data correctly
        :param max_waittime:
        :return:
        """
        self.logger.info('Checking X hosts are receiving data...')
        results = THREADED_FPGA_FUNC(
            self.hosts, timeout=max_waittime+1,
            target_function=('check_rx', (max_waittime,),))
        all_okay = True
        for _v in results.values():
            all_okay = all_okay and _v
        if not all_okay:
            self.logger.error('\tERROR in X-engine rx data.')
        self.logger.info('\tdone.')
        return all_okay

    def vacc_status(self):
        """
        Get a dictionary of the vacc status registers for all
        x-engines.
        :return: {}
        """
        return THREADED_FPGA_FUNC(self.hosts, timeout=10,
                                  target_function='vacc_get_status')

    def _vacc_sync_check_reset(self):
        """
        Do the vaccs need resetting before a synch?
        :return:
        """
        vaccstat = THREADED_FPGA_FUNC(
            self.hosts, timeout=10,
            target_function='vacc_check_arm_load_counts')
        reset_required = False
        for xhost, result in vaccstat.items():
            if result:
                self.logger.info('xeng_vacc_sync: %s has a vacc that '
                                 'needs resetting' % xhost)
                reset_required = True

        if reset_required:
            THREADED_FPGA_FUNC(self.hosts, timeout=10,
                               target_function='vacc_reset')
            vaccstat = THREADED_FPGA_FUNC(
                self.hosts, timeout=10,
                target_function='vacc_check_reset_status')
            for xhost, result in vaccstat.items():
                if not result:
                    errstr = 'xeng_vacc_sync: resetting vaccs on ' \
                             '%s failed.' % xhost
                    self.logger.error(errstr)
                    raise RuntimeError(errstr)

    def _vacc_sync_create_loadtime(self, min_loadtime):
        """
        Calculate the load time for the vacc synch based on a
        given minimum load time
        :param min_loadtime:
        :return: the vacc load time, in seconds since the UNIX epoch
        """
        # how long should we wait for the vacc load
        self.logger.info('Vacc sync time not specified. Syncing in '
                         '%2.2f seconds\' time.' % (min_loadtime*2))
        t_now = time.time()
        vacc_load_time = t_now + (min_loadtime*2)

        if vacc_load_time < (t_now + min_loadtime):
            raise RuntimeError(
                'Cannot load at a time in the past. '
                'Need at least %2.2f seconds lead time. You asked for '
                '%s.%i, and it is now %s.%i.' % (
                    min_loadtime,
                    time.strftime('%H:%M:%S', time.gmtime(vacc_load_time)),
                    (vacc_load_time-int(vacc_load_time))*100,
                    time.strftime('%H:%M:%S', time.gmtime(t_now)),
                    (t_now-int(t_now))*100))

        self.logger.info('    xeng vaccs will sync at %s (in %2.2fs)'
                         % (time.ctime(vacc_load_time), vacc_load_time-t_now))
        return vacc_load_time

    def _vacc_sync_calc_load_mcount(self, vacc_loadtime):
        """
        Calculate the loadtime in clock ticks
        :param vacc_loadtime:
        :return:
        """
        ldmcnt = int(self.corr.mcnt_from_time(vacc_loadtime))
        self.logger.debug('$$$$$$$$$$$ - ldmcnt = %i' % ldmcnt)
        _ldmcnt_orig = ldmcnt
        _cfgd = self.corr.configd
        n_chans = int(_cfgd['fengine']['n_chans'])
        xeng_acc_len = int(_cfgd['xengine']['xeng_accumulation_len'])

        quantisation_bits = int(
            numpy.log2(n_chans) + 1 +
            numpy.log2(xeng_acc_len))

        self.logger.debug('$$$$$$$$$$$ - quant bits = %i' % quantisation_bits)
        ldmcnt = ((ldmcnt >> quantisation_bits) + 1) << quantisation_bits
        self.logger.debug('$$$$$$$$$$$ - ldmcnt quantised = %i' % ldmcnt)
        self.logger.debug('$$$$$$$$$$$ - ldmcnt diff = %i' % (
            ldmcnt - _ldmcnt_orig))
        if _ldmcnt_orig > ldmcnt:
            raise RuntimeError('Quantising the ldmcnt has broken it: %i -> '
                               '%i, diff(%i)' % (_ldmcnt_orig, ldmcnt,
                                                 ldmcnt - _ldmcnt_orig))
        time_from_mcnt = self.corr.time_from_mcnt(ldmcnt)
        t_now = time.time()
        if time_from_mcnt <= t_now:
            self.logger.warning('    Warning: the board timestamp has '
                                'probably wrapped! mcnt_time(%.3f) '
                                'time.time(%.3f)' % (time_from_mcnt, t_now))
        return ldmcnt

    def _vacc_sync_print_vacc_statuses(self, vstatus):
        """
        Print the vacc statuses to the logger
        :param vstatus:
        :return:
        """
        self.logger.info('vacc statuses:')
        for _host in self.hosts:
            self.logger.info('    %s:' % _host.host)
            for _ctr, _status in enumerate(vstatus[_host.host]):
                self.logger.info('        %i: %s' % (_ctr, _status))

    def _vacc_sync_check_counts_initial(self):
        """
        Check the arm and load counts initially
        :return:
        """
        # read the current arm and load counts
        vacc_status = self.vacc_status()
        arm_count0 = vacc_status[self.hosts[0].host][0]['armcount']
        load_count0 = vacc_status[self.hosts[0].host][0]['loadcount']
        # check the xhosts load and arm counts
        for host in self.hosts:
            for status in vacc_status[host.host]:
                _bad_ldcnt = status['loadcount'] != load_count0
                _bad_armcnt = status['armcount'] != arm_count0
                if _bad_ldcnt or _bad_armcnt:
                    _err = 'All hosts do not have matching arm and ' \
                           'load counts.'
                    self.logger.error(_err)
                    self._vacc_sync_print_vacc_statuses(vacc_status)
                    raise RuntimeError(_err)
        self.logger.info('    Before arming: arm_count(%i) load_count(%i)' %
                         (arm_count0, load_count0))
        return arm_count0, load_count0

    def _vacc_sync_check_arm_count(self, armcount_initial):
        """
        Check that the arm count increased on all hosts.
        :param armcount_initial: the arm count read before arming
        :return: True if all hosts incremented their arm count, else False
        """
        vacc_status = self.vacc_status()
        arm_count_new = vacc_status[self.hosts[0].host][0]['armcount']
        for host in self.hosts:
            for status in vacc_status[host.host]:
                if ((status['armcount'] != arm_count_new) or
                        (status['armcount'] != armcount_initial + 1)):
                    _err = 'xeng_vacc_sync: all hosts do not have ' \
                           'matching arm counts or arm count did ' \
                           'not increase.'
                    self.logger.error(_err)
                    self._vacc_sync_print_vacc_statuses(vacc_status)
                    return False
        self.logger.info('    Done arming')
        return True

    def _vacc_sync_check_loadtimes(self):
        """
        Check that the vacc load-time registers match across all hosts.
        :return: True if all hosts hold the same load time, else False
        """
        lsws = THREADED_FPGA_OP(
            self.hosts, timeout=10,
            target_function=(
                lambda x: x.registers.vacc_time_lsw.read()['data']),)
        msws = THREADED_FPGA_OP(
            self.hosts, timeout=10,
            target_function=(
                lambda x: x.registers.vacc_time_msw.read()['data']),)
        _host0 = self.hosts[0].host
        for host in self.hosts:
            if ((lsws[host.host]['lsw'] != lsws[_host0]['lsw']) or
                    (msws[host.host]['msw'] != msws[_host0]['msw'])):
                _err = 'xeng_vacc_sync: all hosts do not have matching ' \
                       'vacc LSWs and MSWs'
                self.logger.error(_err)
                self.logger.error('LSWs: %s' % lsws)
                self.logger.error('MSWs: %s' % msws)
                vacc_status = self.vacc_status()
                self._vacc_sync_print_vacc_statuses(vacc_status)
                return False
        lsw = lsws[self.hosts[0].host]['lsw']
        msw = msws[self.hosts[0].host]['msw']
        xldtime = (msw << 32) | lsw
        self.logger.info('    x engines have vacc ld time %i' % xldtime)
        return True
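
    # Editor's note: the 64-bit load time above is reassembled from two
    # 32-bit registers. A worked example with hypothetical values:
    #   msw = 0x1, lsw = 0x80000000
    #   (msw << 32) | lsw == 0x180000000 == 6442450944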

    def _vacc_sync_wait_for_arm(self, load_mcount):
        """
        Sleep until the load mcount should have passed, giving the armed
        vaccs a chance to trigger.
        :param load_mcount: the mcount at which the vaccs were set to load
        :return:
        """
        t_now = time.time()
        time_from_mcnt = self.corr.time_from_mcnt(load_mcount)
        wait_time = time_from_mcnt - t_now + 0.2
        if wait_time <= 0:
            self.logger.error('    This is wonky - why is the wait_time '
                              'less than zero? %.3f' % wait_time)
            self.logger.error('    corr synch epoch: %i' %
                              self.corr.get_synch_time())
            self.logger.error('    time.time(): %.10f' % t_now)
            self.logger.error('    time_from_mcnt: %.10f' % time_from_mcnt)
            self.logger.error('    ldmcnt: %i' % load_mcount)
            # hack: fall back to a fixed wait; sleeping until t_now + 4
            # would treat an absolute epoch timestamp as a duration
            wait_time = 4

        self.logger.info('    Waiting %2.2f seconds for arm to '
                         'trigger.' % wait_time)
        time.sleep(wait_time)

    def _vacc_sync_check_load_count(self, load_count0):
        """
        Did the vaccs load counts increment correctly?
        :param load_count0:
        :return:
        """
        vacc_status = self.vacc_status()
        load_count_new = vacc_status[self.hosts[0].host][0]['loadcount']
        for host in self.hosts:
            for status in vacc_status[host.host]:
                if ((status['loadcount'] != load_count_new) or
                        (status['loadcount'] != load_count0 + 1)):
                    self.logger.error('vacc did not trigger!')
                    self._vacc_sync_print_vacc_statuses(vacc_status)
                    return False
        self.logger.info('    All vaccs triggered correctly.')
        return True

    def _vacc_sync_final_check(self):
        """
        Check the vacc status, errors and accumulations
        :return:
        """
        self.logger.info('\tChecking for errors & accumulations...')
        vac_okay = self._vacc_check_okay_initial()
        if not vac_okay:
            vacc_status = self.vacc_status()
            vacc_error_detail = THREADED_FPGA_FUNC(
                self.hosts, timeout=5,
                target_function='vacc_get_error_detail')
            self.logger.error('\t\txeng_vacc_sync: exited on vacc error')
            self.logger.error('\t\txeng_vacc_sync: vacc statii:')
            for host, item in vacc_status.items():
                self.logger.error('\t\t\t%s: %s' % (host, str(item)))
            self.logger.error('\t\txeng_vacc_sync: vacc errors:')
            for host, item in vacc_error_detail.items():
                self.logger.error('\t\t\t%s: %s' % (host, str(item)))
            self.logger.error('\t\txeng_vacc_sync: exited on vacc error')
            return False
        self.logger.info('\t...accumulations rolling in without error.')
        return True

    def _vacc_check_okay_initial(self):
        """
        After an initial setup, is the vacc okay?
        Are the error counts zero and the counters
        ticking over?
        :return: True or False
        """
        vacc_status = self.vacc_status()
        note_errors = False
        for host in self.hosts:
            for xeng_ctr, status in enumerate(vacc_status[host.host]):
                _msgpref = '{h}:{x} - '.format(h=host.host, x=xeng_ctr)
                errs = status['errors']
                thresh = self.corr.qdr_vacc_error_threshold
                if (errs > 0) and (errs < thresh):
                    self.logger.warn(
                        '\t\t{pref}{thresh} > vacc errors > 0. Que '
                        'pasa?'.format(pref=_msgpref, thresh=thresh))
                    note_errors = True
                elif (errs > 0) and (errs >= thresh):
                    self.logger.error(
                        '\t\t{pref}vacc errors > {thresh}. Problems.'.format(
                            pref=_msgpref, thresh=thresh))
                    return False
                if status['count'] <= 0:
                    self.logger.error(
                        '\t\t{}vacc counts <= 0. Que pasa?'.format(_msgpref))
                    return False
        if note_errors:
            # investigate the errors further, what caused them?
            if self._vacc_non_parity_errors():
                self.logger.error('\t\t\tsome vacc errors, but they\'re not '
                                  'parity errors. Problems.')
                return False
            self.logger.info('\t\tvacc_check_okay_initial: mostly okay, some '
                             'QDR parity errors')
        else:
            self.logger.info('\t\tvacc_check_okay_initial: all okay')
        return True

    def _vacc_non_parity_errors(self):
        """
        Are VACC errors other than parity errors occurring?
        :return:
        """
        _loops = 2
        parity_errors = 0
        for ctr in range(_loops):
            detail = THREADED_FPGA_FUNC(
                self.hosts, timeout=5, target_function='vacc_get_error_detail')
            for xhost in detail:
                for vals in detail[xhost]:
                    for field in vals:
                        if vals[field] > 0:
                            if field != 'parity':
                                return True
                            else:
                                parity_errors += 1
            if ctr < _loops - 1:
                time.sleep(self.get_acc_time() * 1.1)
        if parity_errors == 0:
            self.logger.error('\t\tThat\'s odd, VACC errors reported but '
                              'nothing caused them?')
            return True
        return False

    def vacc_sync(self):
        """
        Sync the vector accumulators on all the x-engines.
        Assumes that the x-engines are all receiving data.
        :return: the vacc synch time, in seconds since the UNIX epoch
        """

        if self.vacc_synch_running.is_set():
            self.logger.error('vacc_sync called when it was already running?')
            return
        self.vacc_synch_running.set()
        min_load_time = 2

        attempts = 0
        try:
            while True:
                attempts += 1

                if attempts > MAX_VACC_SYNCH_ATTEMPTS:
                    raise VaccSynchAttemptsMaxedOut(
                        'Reached maximum vacc synch attempts, aborting')

                # check if the vaccs need resetting
                self._vacc_sync_check_reset()

                # calculate an mcount from the current time as an early
                # sanity check of the quantisation (the result is discarded)
                self._vacc_sync_calc_load_mcount(time.time())

                # work out the load time
                vacc_load_time = self._vacc_sync_create_loadtime(min_load_time)

                # set the vacc load time on the xengines
                load_mcount = self._vacc_sync_calc_load_mcount(vacc_load_time)

                # set the load mcount on the x-engines
                self.logger.info('    Applying load time: %i.' % load_mcount)
                THREADED_FPGA_FUNC(
                    self.hosts, timeout=10,
                    target_function=('vacc_set_loadtime', (load_mcount,),))

                # check the current counts
                (arm_count0,
                 load_count0) = self._vacc_sync_check_counts_initial()

                # arm the xhosts
                THREADED_FPGA_FUNC(
                    self.hosts, timeout=10, target_function='vacc_arm')

                # did the arm count increase?
                if not self._vacc_sync_check_arm_count(arm_count0):
                    continue

                # check that the load time was stored correctly
                if not self._vacc_sync_check_loadtimes():
                    continue

                # wait for the vaccs to arm
                self._vacc_sync_wait_for_arm(load_mcount)

                # check the status to see that the load count increased
                if not self._vacc_sync_check_load_count(load_count0):
                    continue

                # allow vacc to flush and correctly populate parity bits:
                self.logger.info('    Waiting %2.2fs for an accumulation to '
                                 'flush, to correctly populate parity bits.' %
                                 self.get_acc_time())
                time.sleep(self.get_acc_time() + 0.2)

                self.logger.info('    Clearing status and resetting counters.')
                THREADED_FPGA_FUNC(self.hosts, timeout=10,
                                   target_function='clear_status')

                # wait for a good accumulation to finish.
                self.logger.info('    Waiting %2.2fs for an accumulation to '
                                 'flush before checking counters.' %
                                 self.get_acc_time())
                time.sleep(self.get_acc_time() + 0.2)

                # check the vacc status, errors and accumulations
                if not self._vacc_sync_final_check():
                    continue

                # done
                synch_time = self.corr.time_from_mcnt(load_mcount)
                self.vacc_synch_running.clear()
                return synch_time
        except KeyboardInterrupt:
            self.vacc_synch_running.clear()
        except VaccSynchAttemptsMaxedOut as e:
            self.vacc_synch_running.clear()
            self.logger.error(str(e))
            raise e
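
    # Editor's note (usage sketch; `xops` is a hypothetical
    # XEngineOperations instance): vacc_sync() is guarded by the
    # vacc_synch_running Event so a concurrent call bails out early, and
    # each failed check `continue`s the retry loop until
    # MAX_VACC_SYNCH_ATTEMPTS is exceeded.
    #
    #   sync_time = xops.vacc_sync()   # seconds since the UNIX epoch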

    def set_acc_time(self, acc_time_s, vacc_resync=True):
        """
        Set the vacc accumulation length based on a required dump time,
        in seconds
        :param acc_time_s: new dump time, in seconds
        :param vacc_resync: force a vacc resynchronisation
        :return:
        """
        if use_xeng_sim:
            raise RuntimeError('That\'s not an option anymore.')
        new_acc_len = (
            (self.corr.sample_rate_hz * acc_time_s) /
            (self.corr.xeng_accumulation_len * self.corr.n_chans * 2.0))
        new_acc_len = round(new_acc_len)
        self.corr.logger.info('set_acc_time: %.3fs -> new_acc_len(%i)' %
                              (acc_time_s, new_acc_len))
        self.set_acc_len(new_acc_len, vacc_resync)
        if self.corr.sensor_manager:
            sensor = self.corr.sensor_manager.sensor_get('integration-time')
            sensor.set_value(self.get_acc_time())

    def get_acc_time(self):
        """
        Get the dump time currently being used.

        Note: this will only be correct if the accumulation time was set
        via this correlator object instance, since cached values are used
        for the calculation; the number of accumulations is _not_ read from
        the FPGAs.
        :return: the dump time, in seconds
        """
        return (self.corr.xeng_accumulation_len * self.corr.accumulation_len *
                self.corr.n_chans * 2.0) / self.corr.sample_rate_hz
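
    # Editor's note (illustrative arithmetic, hypothetical values):
    #   xeng_accumulation_len = 256, accumulation_len = 816,
    #   n_chans = 4096, sample_rate_hz = 1712e6
    #   (256 * 816 * 4096 * 2.0) / 1712e6 ~= 1.00 seconds per dump
    # set_acc_time() inverts this relationship to find the nearest integer
    # accumulation_len for a requested dump time.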

    def get_acc_len(self):
        """
        Read the acc len currently programmed into the FPGA.
        :return:
        """
        return self.hosts[0].registers.acc_len.read_uint()

    def set_acc_len(self, acc_len=None, vacc_resync=True):
        """
        Set the QDR vector accumulation length.
        :param acc_len: the new accumulation length; if None, the cached
            value in self.corr.accumulation_len is reused
        :param vacc_resync: force a vacc resynchronisation
        :return:
        """
        if (acc_len is not None) and (acc_len <= 0):
            _err = 'new acc_len of %i makes no sense' % acc_len
            self.logger.error(_err)
            raise RuntimeError(_err)
        reenable_timer = False
        if self.vacc_check_enabled.is_set():
            self.vacc_check_timer_stop()
            reenable_timer = True
        if acc_len is not None:
            self.corr.accumulation_len = acc_len
        THREADED_FPGA_OP(
            self.hosts, timeout=10,
            target_function=(
                lambda fpga_:
                fpga_.registers.acc_len.write_int(self.corr.accumulation_len),))
        if self.corr.sensor_manager:
            sensor = self.corr.sensor_manager.sensor_get('n-accs')
            sensor.set_value(self.corr.accumulation_len)
        self.logger.info('Set vacc accumulation length %d system-wide '
                         '(%.2f seconds)' %
                         (self.corr.accumulation_len, self.get_acc_time()))
        self.corr.speadops.update_metadata([0x1015, 0x1016])
        if vacc_resync:
            self.vacc_sync()
        if reenable_timer:
            self.vacc_check_timer_start()

    def xeng_tx_enable(self, data_stream):
        """
        Start transmission of data streams from the x-engines
        :param data_stream: the data stream on which to act
        :return:
        """
        dstrm = data_stream or self.data_stream
        THREADED_FPGA_OP(
                self.hosts, timeout=5,
                target_function=(
                    lambda fpga_:
                    fpga_.registers.control.write(gbe_txen=True),))
        self.logger.info('X-engine output enabled')

    def xeng_tx_disable(self, data_stream):
        """
        Stop transmission of data streams from the x-engines
        :param data_stream: the data stream on which to act
        :return:
        """
        dstrm = data_stream or self.data_stream
        THREADED_FPGA_OP(
                self.hosts, timeout=5,
                target_function=(
                    lambda fpga_:
                    fpga_.registers.control.write(gbe_txen=False),))
        self.logger.info('X-engine output disabled')

    def spead_meta_update_stream_destination(self):
        """
        Update the SPEAD metadata items describing this data stream's
        destination (UDP port and IP address).
        :return:
        """
        meta_ig = self.data_stream.meta_ig
        self.corr.speadops.add_item(
            meta_ig,
            name='rx_udp_port', id=0x1022,
            description='Destination UDP port for %s data '
                        'output.' % self.data_stream.name,
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=self.data_stream.destination.port)

        ipstr = numpy.array(str(self.data_stream.destination.ip))
        self.corr.speadops.add_item(
            meta_ig,
            name='rx_udp_ip_str', id=0x1024,
            description='Destination IP address for %s data '
                        'output.' % self.data_stream.name,
            shape=ipstr.shape,
            dtype=ipstr.dtype,
            value=ipstr)

    # x-engine-specific SPEAD operations
    def spead_meta_update_all(self):
        """
        Update metadata for this correlator's xengine output.
        :return:
        """
        meta_ig = self.data_stream.meta_ig

        self.corr.speadops.item_0x1007(meta_ig)

        self.corr.speadops.add_item(
            meta_ig,
            name='n_bls', id=0x1008,
            description='Number of baselines in the data stream.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=len(self.corr.baselines))

        self.corr.speadops.add_item(
            meta_ig,
            name='n_chans', id=0x1009,
            description='Number of frequency channels in an integration.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=self.corr.n_chans)

        self.corr.speadops.item_0x100a(meta_ig)

        n_xengs = len(self.corr.xhosts) * self.corr.x_per_fpga
        self.corr.speadops.add_item(
            meta_ig,
            name='n_xengs', id=0x100B,
            description='The number of x-engines in the system.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=n_xengs)

        bls_ordering = numpy.array(
            [baseline for baseline in self.corr.baselines])
        # this is a list of the baseline stream pairs, e.g. ['ant0x' 'ant0y']
        self.corr.speadops.add_item(
            meta_ig,
            name='bls_ordering', id=0x100C,
            description='The baseline ordering in the output data stream.',
            shape=bls_ordering.shape,
            dtype=bls_ordering.dtype,
            value=bls_ordering)

        self.corr.speadops.item_0x100e(meta_ig)

        self.corr.speadops.add_item(
            meta_ig,
            name='center_freq', id=0x1011,
            description='The on-sky centre-frequency.',
            shape=[], format=[('f', 64)],
            value=int(self.corr.configd['fengine']['true_cf']))

        self.corr.speadops.add_item(
            meta_ig,
            name='bandwidth', id=0x1013,
            description='The input (analogue) bandwidth of the system.',
            shape=[], format=[('f', 64)],
            value=int(self.corr.configd['fengine']['bandwidth']))

        self.corr.speadops.item_0x1015(meta_ig)
        self.corr.speadops.item_0x1016(meta_ig)
        self.corr.speadops.item_0x101e(meta_ig)

        self.corr.speadops.add_item(
            meta_ig,
            name='xeng_acc_len', id=0x101F,
            description='Number of spectra accumulated inside X engine. '
                        'Determines minimum integration time and '
                        'user-configurable integration time stepsize. '
                        'X-engine correlator internals.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=self.corr.xeng_accumulation_len)

        self.corr.speadops.item_0x1020(meta_ig)

        pkt_len = int(self.corr.configd['fengine']['10gbe_pkt_len'])
        self.corr.speadops.add_item(
            meta_ig,
            name='feng_pkt_len', id=0x1021,
            description='Payload size of 10GbE packet exchange between '
                        'F and X engines in 64 bit words. Usually equal '
                        'to the number of spectra accumulated inside X '
                        'engine. F-engine correlator internals.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=pkt_len)

        self.spead_meta_update_stream_destination()

        port = int(self.corr.configd['fengine']['10gbe_port'])
        self.corr.speadops.add_item(
            meta_ig,
            name='feng_udp_port', id=0x1023,
            description='Port for F-engines 10Gbe links in the system.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=port)

        ipstr = numpy.array(self.corr.configd['fengine']['10gbe_start_ip'])
        self.corr.speadops.add_item(
            meta_ig,
            name='feng_start_ip', id=0x1025,
            description='Start IP address for F-engines in the system.',
            shape=ipstr.shape,
            dtype=ipstr.dtype,
            value=ipstr)

        self.corr.speadops.add_item(
            meta_ig,
            name='xeng_rate', id=0x1026,
            description='Target clock rate of processing engines (xeng).',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=self.corr.xeng_clk)

        self.corr.speadops.item_0x1027(meta_ig)

        x_per_fpga = int(self.corr.configd['xengine']['x_per_fpga'])
        self.corr.speadops.add_item(
            meta_ig,
            name='x_per_fpga', id=0x1041,
            description='Number of X engines per FPGA host.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=x_per_fpga)

        self.corr.speadops.add_item(
            meta_ig,
            name='ddc_mix_freq', id=0x1043,
            description='Digital downconverter mixing frequency as a fraction '
                        'of the ADC sampling frequency. eg: 0.25. Set to zero '
                        'if no DDC is present.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=0)

        self.corr.speadops.item_0x1045(meta_ig)
        self.corr.speadops.item_0x1046(meta_ig)

        self.corr.speadops.add_item(
            meta_ig,
            name='xeng_out_bits_per_sample', id=0x1048,
            description='The number of bits per value of the xeng '
                        'accumulator output. Note this is for a '
                        'single value, not the combined complex size.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=self.corr.xeng_outbits)

        self.corr.speadops.add_item(
            meta_ig,
            name='f_per_fpga', id=0x1049,
            description='Number of F engines per FPGA host.',
            shape=[], format=[('u', SPEAD_ADDRSIZE)],
            value=self.corr.f_per_fpga)

        self.corr.speadops.item_0x104a(meta_ig)
        self.corr.speadops.item_0x104b(meta_ig)

        self.corr.speadops.item_0x1400(meta_ig)

        self.corr.speadops.item_0x1600(meta_ig)

        self.corr.speadops.add_item(
            meta_ig,
            name='flags_xeng_raw', id=0x1601,
            description='Flags associated with xeng_raw data output. '
                        'bit 34 - corruption or data missing during integration '
                        'bit 33 - overrange in data path '
                        'bit 32 - noise diode on during integration '
                        'bits 0 - 31 reserved for internal debugging',
            shape=[], format=[('u', SPEAD_ADDRSIZE)])

        self.corr.speadops.add_item(
            meta_ig,
            name='xeng_raw', id=0x1800,
            description='Raw data for %i xengines in the system. This item '
                        'represents a full spectrum (all frequency channels) '
                        'assembled from lowest frequency to highest '
                        'frequency. Each frequency channel contains the data '
                        'for all baselines (n_bls given by SPEAD ID 0x100b). '
                        'Each value is a complex number - two (real and '
                        'imaginary) signed integers.' % n_xengs,
            # dtype=numpy.int32,
            dtype=numpy.dtype('>i4'),
            shape=[self.corr.n_chans, len(self.corr.baselines), 2])
            # shape=[self.corr.n_chans * len(self.corr.baselines), 2])
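
    # Editor's note: in the items above, scalars use shape=[] with an
    # immediate ('u', SPEAD_ADDRSIZE) format, while array items
    # (bls_ordering, xeng_raw) carry an explicit numpy shape/dtype so that
    # receivers can size their buffers before any payload arrives.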

    def spead_meta_issue_all(self, data_stream):
        """
        Issue = update the metadata then send it.
        :param data_stream: The DataStream object for which to send metadata
        :return: True if the callback transmits the metadata as well
        """
        dstrm = data_stream or self.data_stream
        self.spead_meta_update_all()
        dstrm.meta_transmit()
        self.logger.info('Issued SPEAD data descriptor for data stream %s '
                         'to %s.' % (dstrm.name,
                                     dstrm.meta_destination))
        return True
コード例 #29
0
ファイル: websocket_test.py プロジェクト: lilydjwg/tornado
class WebSocketTest(WebSocketBaseTestCase):
    def get_app(self):
        self.close_future = Future()  # type: Future[None]
        return Application(
            [
                ("/echo", EchoHandler, dict(close_future=self.close_future)),
                ("/non_ws", NonWebSocketHandler),
                ("/header", HeaderHandler, dict(close_future=self.close_future)),
                (
                    "/header_echo",
                    HeaderEchoHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/close_reason",
                    CloseReasonHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/error_in_on_message",
                    ErrorInOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/async_prepare",
                    AsyncPrepareHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/path_args/(.*)",
                    PathArgsHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/coroutine",
                    CoroutineOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
                (
                    "/subprotocol",
                    SubprotocolHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/open_coroutine",
                    OpenCoroutineHandler,
                    dict(close_future=self.close_future, test=self),
                ),
                ("/error_in_open", ErrorInOpenHandler),
                ("/error_in_async_open", ErrorInAsyncOpenHandler),
            ],
            template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
        )

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        super(WebSocketTest, self).tearDown()
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch("/echo")
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "13",
            },
        )
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "12",
            },
        )
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self.ws_connect("/echo")
        yield ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    def test_websocket_callbacks(self):
        websocket_connect(
            "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
        )
        ws = self.wait().result()
        ws.write_message("hello")
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, "hello")
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(b"hello \xe9", binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(u"hello \u00e9")
        response = yield ws.read_message()
        self.assertEqual(response, u"hello \u00e9")

    @gen_test
    def test_render_message(self):
        ws = yield self.ws_connect("/render")
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self.ws_connect("/error_in_on_message")
        ws.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect("/notfound")
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_network_fail(self):
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    "ws://127.0.0.1:%d/" % port, connect_timeout=3600
                )

    @gen_test
    def test_websocket_close_buffered_data(self):
        ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
        ws.write_message("hello")
        ws.write_message("world")
        # Close the underlying stream.
        ws.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header" % self.get_http_port(),
                headers={"X-Test": "hello"},
            )
        )
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
                headers={"X-Test-Hello": "hello"},
            )
        )
        self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
        self.assertEqual(
            ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
        )

    @gen_test
    def test_server_close_reason(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        ws = yield self.ws_connect("/echo")
        ws.close(1001, "goodbye")
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect("/async_prepare")
        ws.write_message("hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_path_args(self):
        ws = yield self.ws_connect("/path_args/hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect("/coroutine")
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        res = yield ws.read_message()
        self.assertEqual(res, "hello1")
        res = yield ws.read_message()
        self.assertEqual(res, "hello2")

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d" % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d/something" % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "127.0.0.1:%d" % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        headers = {"Origin": "http://somewhereelse.com"}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        url = "ws://localhost:%d/echo" % port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {"Origin": "http://subtenant.localhost"}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        ws = yield self.ws_connect(
            "/subprotocol", subprotocols=["badproto", "goodproto"]
        )
        self.assertEqual(ws.selected_subprotocol, "goodproto")
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        ws = yield self.ws_connect("/subprotocol")
        self.assertIs(ws.selected_subprotocol, None)
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        self.message_sent = Event()
        ws = yield self.ws_connect("/open_coroutine")
        yield ws.write_message("hello")
        self.message_sent.set()
        res = yield ws.read_message()
        self.assertEqual(res, "ok")

    @gen_test
    def test_error_in_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_error_in_async_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_async_open")
            res = yield ws.read_message()
        self.assertIsNone(res)
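
# Editor's sketch (assumption, not part of the original test file): a
# minimal version of the "/echo" endpoint these tests exercise. The real
# EchoHandler in tornado's websocket_test.py also threads a close_future
# through the handlers for shutdown bookkeeping.
from tornado.websocket import WebSocketHandler

class EchoHandlerSketch(WebSocketHandler):
    def on_message(self, message):
        # Echo text frames back as text and binary frames back as binary.
        self.write_message(message, binary=isinstance(message, bytes))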
コード例 #30
0
ファイル: tornado.py プロジェクト: FloFaber/Sudoku
class DBusRouter:
    def __init__(self, conn: DBusConnection):
        self.conn = conn
        self._replies = ReplyMatcher()
        self._filters = MessageFilters()
        self._stop_receiving = Event()
        IOLoop.current().add_callback(self._receiver)

        # For backwards compatibility - old-style signal callbacks
        self.router = Router(Future)

    async def send(self, message, *, serial=None):
        await self.conn.send(message, serial=serial)

    async def send_and_get_reply(self, message):
        check_replyable(message)
        if self._stop_receiving.is_set():
            raise RouterClosed("This DBusRouter has stopped")

        serial = next(self.conn.outgoing_serial)

        with self._replies.catch(serial, Future()) as reply_fut:
            await self.send(message, serial=serial)
            return (await reply_fut)

    def filter(self, rule, *, queue: Optional[Queue] = None, bufsize=1):
        """Create a filter for incoming messages

        Usage::

            with router.filter(rule) as queue:
                matching_msg = await queue.get()

        :param jeepney.MatchRule rule: Catch messages matching this rule
        :param tornado.queues.Queue queue: Matched messages will be added to this
        :param int bufsize: If no queue is passed in, create one with this size
        """
        return FilterHandle(self._filters, rule, queue or Queue(bufsize))

    def stop(self):
        self._stop_receiving.set()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        return False

    # Backwards compatible interface (from old DBusConnection) --------

    @property
    def unique_name(self):
        return self.conn.unique_name

    async def send_message(self, message: Message):
        if (message.header.message_type == MessageType.method_return and
                not (message.header.flags & MessageFlag.no_reply_expected)):
            return unwrap_msg(await self.send_and_get_reply(message))
        else:
            await self.send(message)

    # Code to run in receiver task ------------------------------------

    def _dispatch(self, msg: Message):
        """Handle one received message"""
        if self._replies.dispatch(msg):
            return

        for filter in self._filters.matches(msg):
            try:
                filter.queue.put_nowait(msg)
            except QueueFull:
                pass

    async def _receiver(self):
        """Receiver loop - runs in a separate task"""
        try:
            while True:
                for coro in as_completed(
                    [self.conn.receive(),
                     self._stop_receiving.wait()]):
                    msg = await coro
                    if msg is None:
                        return  # Stopped
                    self._dispatch(msg)
                    self.router.incoming(msg)
        finally:
            self.is_running = False
            # Send errors to any tasks still waiting for a message.
            self._replies.drop_all()
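
# Editor's sketch (assumption): minimal usage of the router above, where
# `conn` is an open jeepney DBusConnection and `rule` a jeepney MatchRule,
# as described in the filter() docstring.
async def wait_for_match(conn, rule):
    with DBusRouter(conn) as router:       # stops the receiver on exit
        with router.filter(rule) as queue:
            return await queue.get()       # first message matching the rule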
コード例 #31
0
ファイル: stores.py プロジェクト: mivade/tornadose
class RedisStore(BaseStore):
    """Publish data via a Redis backend.

    This data store works in a similar manner as
    :class:`DataStore`. The primary advantage is that external
    programs can be used to publish data to be consumed by clients.

    The ``channel`` keyword argument specifies which Redis channel to
    publish to and defaults to ``tornadose``.

    All remaining keyword arguments are passed directly to the
    ``redis.StrictRedis`` constructor. See `redis-py`__'s
    documentation for details.

    New messages are read in a background thread via a
    :class:`concurrent.futures.ThreadPoolExecutor`. This requires
    either Python >= 3.2 or the backported ``futures`` module to be
    installed.

    __ https://redis-py.readthedocs.org/en/latest/

    :raises ConnectionError: when the Redis host is not pingable

    """
    def initialize(self, channel='tornadose', **kwargs):
        if redis is None:
            raise RuntimeError("The redis module is required to use RedisStore")
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.channel = channel
        self.messages = Queue()
        self._done = Event()

        self._redis = redis.StrictRedis(**kwargs)
        self._redis.ping()
        self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True)
        self._pubsub.subscribe(self.channel)

        self.publish()

    def submit(self, message, debug=False):
        self._redis.publish(self.channel, message)
        if debug:
            logger.debug(message)
            self._redis.setex(self.channel, 5, message)

    def shutdown(self):
        """Stop the publishing loop."""
        self._done.set()
        self.executor.shutdown(wait=False)

    @run_on_executor
    def _get_message(self):
        data = self._pubsub.get_message(timeout=1)
        if data is not None:
            data = data['data']
        return data

    @gen.coroutine
    def publish(self):
        while not self._done.is_set():
            data = yield self._get_message()
            if len(self.subscribers) > 0 and data is not None:
                [subscriber.submit(data) for subscriber in self.subscribers]
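
# Editor's sketch (assumption about BaseStore's constructor, which is taken
# to forward keyword arguments to initialize()):
#
#   store = RedisStore(channel='updates', host='localhost', port=6379)
#   store.submit('hello subscribers')   # published via Redis pub/sub
#   store.shutdown()                    # stops the background publish loop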
コード例 #32
0
ファイル: executor.py プロジェクト: cowlicks/distributed
class Executor(object):
    """ Distributed executor with data dependencies

    This executor resembles executors in concurrent.futures but also allows
    Futures within submit/map calls.

    Provide center address on initialization

    >>> executor = Executor(('127.0.0.1', 8787))  # doctest: +SKIP

    Use ``submit`` method like normal

    >>> a = executor.submit(add, 1, 2)  # doctest: +SKIP
    >>> b = executor.submit(add, 10, 20)  # doctest: +SKIP

    Additionally, provide results of submit calls (futures) to further submit
    calls:

    >>> c = executor.submit(add, a, b)  # doctest: +SKIP

    This allows for the dynamic creation of complex dependencies.
    """
    def __init__(self, center, start=True, delete_batch_time=1):
        self.center = coerce_to_rpc(center)
        self.futures = dict()
        self.refcount = defaultdict(lambda: 0)
        self.dask = dict()
        self.restrictions = dict()
        self.loop = IOLoop()
        self.report_queue = Queue()
        self.scheduler_queue = Queue()
        self._shutdown_event = Event()
        self._delete_batch_time = delete_batch_time

        if start:
            self.start()

    def start(self):
        """ Start scheduler running in separate thread """
        from threading import Thread
        self.loop.add_callback(self._go)
        self._loop_thread = Thread(target=self.loop.start)
        self._loop_thread.start()

    def __enter__(self):
        if not self.loop._running:
            self.start()
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()

    def _inc_ref(self, key):
        self.refcount[key] += 1

    def _dec_ref(self, key):
        self.refcount[key] -= 1
        if self.refcount[key] == 0:
            del self.refcount[key]
            self._release_key(key)

    def _release_key(self, key):
        """ Release key from distributed memory """
        self.futures[key]['event'].clear()
        logger.debug("Release key %s", key)
        del self.futures[key]
        self.scheduler_queue.put_nowait({
            'op': 'release-held-data',
            'key': key
        })

    @gen.coroutine
    def report(self):
        """ Listen to scheduler """
        while True:
            msg = yield self.report_queue.get()
            if msg['op'] == 'close':
                break
            if msg['op'] == 'task-finished':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'finished'
                    self.futures[msg['key']]['event'].set()
            if msg['op'] == 'lost-data':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'lost'
                    self.futures[msg['key']]['event'].clear()
            if msg['op'] == 'task-erred':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'error'
                    self.futures[msg['key']]['event'].set()

    @gen.coroutine
    def _shutdown(self):
        """ Send shutdown signal and wait until _go completes """
        self.report_queue.put_nowait({'op': 'close'})
        self.scheduler_queue.put_nowait({'op': 'close'})
        yield self._shutdown_event.wait()

    def shutdown(self):
        """ Send shutdown signal and wait until scheduler terminates """
        self.report_queue.put_nowait({'op': 'close'})
        self.scheduler_queue.put_nowait({'op': 'close'})
        self.loop.stop()
        self._loop_thread.join()

    @gen.coroutine
    def _go(self):
        """ Setup and run all other coroutines.  Block until finished. """
        self.who_has, self.has_what, self.ncores = yield [
            self.center.who_has(),
            self.center.has_what(),
            self.center.ncores()
        ]
        self.waiting = {}
        self.processing = {}
        self.stacks = {}

        worker_queues = {worker: Queue() for worker in self.ncores}
        delete_queue = Queue()

        coroutines = ([
            self.report(),
            scheduler(self.scheduler_queue, self.report_queue, worker_queues,
                      delete_queue, self.who_has, self.has_what, self.ncores,
                      self.dask, self.restrictions, self.waiting, self.stacks,
                      self.processing),
            delete(self.scheduler_queue, delete_queue, self.center.ip,
                   self.center.port, self._delete_batch_time)
        ] + [
            worker(self.scheduler_queue, worker_queues[w], w, n)
            for w, n in self.ncores.items()
        ])

        results = yield All(coroutines)
        self._shutdown_event.set()

    def submit(self, func, *args, **kwargs):
        """ Submit a function application to the scheduler

        Parameters
        ----------
        func: callable
        *args:
        **kwargs:
        pure: bool (defaults to True)
            Whether or not the function is pure.  Set ``pure=False`` for
            impure functions like ``np.random.random``.
        workers: set, iterable of sets
            A set of worker hostnames on which computations may be performed.
            Leave empty to default to all workers (common case)

        Examples
        --------
        >>> c = executor.submit(add, a, b)  # doctest: +SKIP

        Returns
        -------
        Future

        See Also
        --------
        distributed.executor.Executor.submit:
        """
        if not callable(func):
            raise TypeError(
                "First input to submit must be a callable function")

        key = kwargs.pop('key', None)
        pure = kwargs.pop('pure', True)
        workers = kwargs.pop('workers', None)

        if key is None:
            if pure:
                key = funcname(func) + '-' + tokenize(func, kwargs, *args)
            else:
                key = funcname(func) + '-' + next(tokens)

        if key in self.futures:
            return Future(key, self)

        if kwargs:
            task = (apply, func, args, kwargs)
        else:
            task = (func, ) + args

        if workers is not None:
            restrictions = {key: workers}
        else:
            restrictions = {}

        if key not in self.futures:
            self.futures[key] = {'event': Event(), 'status': 'waiting'}

        logger.debug("Submit %s(...), %s", funcname(func), key)
        self.scheduler_queue.put_nowait({
            'op': 'update-graph',
            'dsk': {
                key: task
            },
            'keys': [key],
            'restrictions': restrictions
        })

        return Future(key, self)

    def map(self, func, *iterables, **kwargs):
        """ Map a function on a sequence of arguments

        Arguments can be normal objects or Futures

        Parameters
        ----------
        func: callable
        iterables: Iterables
        pure: bool (defaults to True)
            Whether or not the function is pure.  Set ``pure=False`` for
            impure functions like ``np.random.random``.
        workers: set, iterable of sets
            A set of worker hostnames on which computations may be performed.
            Leave empty to default to all workers (common case)

        Examples
        --------
        >>> L = executor.map(func, sequence)  # doctest: +SKIP

        Returns
        -------
        list of futures

        See also
        --------
        distributed.executor.Executor.submit
        """
        pure = kwargs.pop('pure', True)
        workers = kwargs.pop('workers', None)
        if not callable(func):
            raise TypeError("First input to map must be a callable function")
        iterables = [list(it) for it in iterables]
        if pure:
            keys = [
                funcname(func) + '-' + tokenize(func, kwargs, *args)
                for args in zip(*iterables)
            ]
        else:
            uid = str(uuid.uuid4())
            keys = [
                funcname(func) + '-' + uid + '-' + next(tokens)
                for i in range(min(map(len, iterables)))
            ]

        if not kwargs:
            dsk = {
                key: (func, ) + args
                for key, args in zip(keys, zip(*iterables))
            }
        else:
            dsk = {
                key: (apply, func, args, kwargs)
                for key, args in zip(keys, zip(*iterables))
            }

        for key in dsk:
            if key not in self.futures:
                self.futures[key] = {'event': Event(), 'status': 'waiting'}

        if isinstance(workers, (list, set)):
            if workers and isinstance(first(workers), (list, set)):
                if len(workers) != len(keys):
                    raise ValueError("You only provided %d worker restrictions"
                                     " for a sequence of length %d" %
                                     (len(workers), len(keys)))
                restrictions = dict(zip(keys, workers))
            else:
                restrictions = {key: workers for key in keys}
        elif workers is None:
            restrictions = {}
        else:
            raise TypeError("Workers must be a list or set of workers or None")

        logger.debug("map(%s, ...)", funcname(func))
        self.scheduler_queue.put_nowait({
            'op': 'update-graph',
            'dsk': dsk,
            'keys': keys,
            'restrictions': restrictions
        })

        return [Future(key, self) for key in keys]

    @gen.coroutine
    def _gather(self, futures):
        futures2, keys = unpack_remotedata(futures)
        keys = list(keys)

        while True:
            yield All([self.futures[key]['event'].wait() for key in keys])
            try:
                data = yield _gather(self.center, keys)
            except KeyError as e:
                self.scheduler_queue.put_nowait({
                    'op': 'missing-data',
                    'missing': e.args
                })
                for key in e.args:
                    self.futures[key]['event'].clear()
            else:
                break

        data = dict(zip(keys, data))

        result = pack_data(futures2, data)
        raise gen.Return(result)

    def gather(self, futures):
        """ Gather futures from distributed memory

        Accepts a future or any nested core container of futures

        Examples
        --------
        >>> from operator import add  # doctest: +SKIP
        >>> e = Executor('127.0.0.1:8787')  # doctest: +SKIP
        >>> x = e.submit(add, 1, 2)  # doctest: +SKIP
        >>> e.gather(x)  # doctest: +SKIP
        3
        >>> e.gather([x, [x], x])  # doctest: +SKIP
        [3, [3], 3]
        """
        return sync(self.loop, self._gather, futures)

    @gen.coroutine
    def _get(self, dsk, keys, restrictions=None):
        flatkeys = list(flatten(keys))
        for key in flatkeys:
            if key not in self.futures:
                self.futures[key] = {'event': Event(), 'status': None}
        futures = {key: Future(key, self) for key in flatkeys}

        self.scheduler_queue.put_nowait({
            'op': 'update-graph',
            'dsk': dsk,
            'keys': flatkeys,
            'restrictions': restrictions or {}
        })

        packed = pack_data(keys, futures)
        result = yield self._gather(packed)
        raise gen.Return(result)

    def get(self, dsk, keys, **kwargs):
        """ Gather futures from distributed memory

        Parameters
        ----------
        dsk: dict
        keys: object, or nested lists of objects
        restrictions: dict (optional)
            A mapping of {key: {set of worker hostnames}} that restricts where
            jobs can take place

        Examples
        --------
        >>> from operator import add  # doctest: +SKIP
        >>> e = Executor('127.0.0.1:8787')  # doctest: +SKIP
        >>> e.get({'x': (add, 1, 2)}, 'x')  # doctest: +SKIP
        3
        """
        return sync(self.loop, self._get, dsk, keys, **kwargs)
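
# Editor's sketch (mirroring the doctests above; `add` is any picklable
# function):
#
#   e = Executor(('127.0.0.1', 8787))
#   futures = e.map(add, range(10), range(10))   # ten independent tasks
#   results = e.gather(futures)                  # blocks until complete
#   e.shutdown()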
コード例 #33
0
ファイル: game.py プロジェクト: marksteve/slash-onenight
class Game(object):

    roles = Enum('Role', [
        'werewolf',
        'seer',
        'robber',
        'troublemaker',
        'villager',
    ])
    ROLES_LABEL = {
        roles.werewolf: ':wolf: Werewolf',
        roles.seer: ':crystal_ball: Seer',
        roles.robber: ':gun: Robber',
        roles.troublemaker: ':smiling_imp: Troublemaker',
        roles.villager: ':man: Villager',
    }

    GAME_STARTING = 'Starting game...'
    CHECKING_PLAYERS = 'Checking players...'
    INVALID_PLAYERS_LENGTH = 'You can only have 3-10 players in this channel ' \
                             'to start a game!'
    GAME_STARTED = 'Everyone, pretend to close your eyes.'

    CENTER_1 = ':black_joker: First card'
    CENTER_2 = ':black_joker: Second card'
    CENTER_3 = ':black_joker: Third card'

    LOOK_OWN_CARD = ':black_joker: Everyone, look at your own card.'
    LOOK_OWN_CARD_ACTION = 'Look'
    LOOK_OWN_CARD_REVEAL = 'You are a {}'

    WEREWOLF_WAKE_UP = ':wolf: Werewolves, wake up and look for other ' \
                       'werewolves.'
    WEREWOLF_ATTACHMENT = 'If you are a werewolf...'
    WEREWOLF_LOOK_FOR_OTHERS = 'Look for others'
    WEREWOLF_LONE = 'You are the lone wolf'
    WEREWOLF_LONE_LOOKED = 'You already looked at a center card'
    WEREWOLF_NOT_LONE = 'You are not the lone wolf'
    WEREWOLF_LIST = 'The other werewolves are: {}'
    WEREWOLF_LONE_ATTACHMENT = 'If you are the lone wolf, check one of the ' \
                               'center cards...'
    WEREWOLF_LOOK_AT_CENTER = 'The {} is a {}'
    WEREWOLF_FALSE = 'You are not a werewolf!'

    SEER_WAKE_UP = ':crystal_ball: Seer, wake up. You may look at another ' \
                   'player\'s card or two of the center cards.'

    def __init__(self, bot_user_id, bot_access_token, channel_id, options):
        self.id = str(uuid4())
        self.bot_user_id = bot_user_id
        self.bot_access_token = bot_access_token
        self.channel_id = channel_id

        self.redis = toredis.Client()
        self.redis.connect(host=options.redis_host)

        self.pubsub = toredis.Client()
        self.pubsub.connect(host=options.redis_host)
        self.pubsub.subscribe(self.id, callback=self.on_button)

    def api(self, path, **kwargs):
        data = kwargs.get('data', {})
        data.setdefault('token', self.bot_access_token)
        kwargs.update(data=data)
        resp = requests.post(
            'https://slack.com/api/{}'.format(path),
            **kwargs).json()
        if not resp['ok']:
            raise RuntimeError(repr(resp['error']))
        return resp

    def start(self):
        resp = self.api('rtm.start')
        conn_future = websocket_connect(
            resp['url'], on_message_callback=self.on_message)
        ioloop = IOLoop.current()
        ioloop.add_future(conn_future, self.on_connect)

    def send(self, msg):
        evt = {
            'type': 'message',
            'channel': self.channel_id,
            'text': msg,
        }
        logging.info('Send: {}'.format(evt))
        self.conn.write_message(json.dumps(evt))

    def on_connect(self, conn_future):
        self.conn = conn_future.result()
        self.send(self.GAME_STARTING)
        players = self.get_players()
        if not (3 <= len(players) <= 10):
            self.send(self.INVALID_PLAYERS_LENGTH)
            return
        roles = self.get_roles(players)
        center = list(range(3))
        self.player_roles = list(zip(players + center, roles))
        logging.info(repr(self.player_roles))
        ioloop = IOLoop.current()
        ioloop.add_callback(self.start_night)
        self.send(self.GAME_STARTED)

    def on_message(self, msg):
        evt = json.loads(msg)
        error = evt.get('error', None)
        if error:
            logging.warning('Error: {}'.format(evt['error']))
            return
        evt_type = evt.get('type', None)
        handler = getattr(self, 'handle_{}'.format(evt_type), None)
        if handler:
            handler(evt)
        else:
            logging.debug('Unhandled event: {}'.format(evt))

    def on_button(self, resp):
        resp_type, callback_id, payload = resp
        if resp_type != 'message':
            return

        data = json.loads(payload)
        user = data['user']
        actions = data['actions']
        callback_id = data['callback_id']
        response_url = data['response_url']

        _, evt, _ = callback_id.split(':')

        if evt == 'look_own_card':
            self.on_look_own_card(user, response_url)
        elif evt == 'werewolf_look_for_others':
            self.on_werewolf_look_for_others(user, response_url)
        elif evt == 'werewolf_look_at_center':
            self.on_werewolf_look_at_center(user, actions, response_url)
        else:
            logging.warning('Unhandled button: {}'.format(evt))

    def get_players(self):
        self.send(self.CHECKING_PLAYERS)
        channel_type = 'channel' if self.channel_id.startswith('C') \
            else 'group'
        resp = self.api(
            '{}s.info'.format(channel_type), data={'channel': self.channel_id})
        channel_info = resp[channel_type]
        players = list(filter(
            lambda m: m != self.bot_user_id, channel_info['members']))
        return players

    def get_roles(self, players):
        roles = [self.roles.werewolf] * 2 \
            + [self.roles.seer, self.roles.robber, self.roles.troublemaker] \
            + [self.roles.villager] * (len(players) - 2)
        random.shuffle(roles)
        return roles

    def get_player_ids(self):
        return list(map(lambda p: p[0], filter(
            lambda p: type(p[0]) != int, self.player_roles)))

    def get_werewolf_ids(self):
        return list(map(lambda p: p[0], filter(
            lambda p: p[1] == self.roles.werewolf, self.player_roles)))

    def get_player_werewolf_ids(self):
        return list(filter(lambda w: type(w) != int, self.get_werewolf_ids()))

    @gen.coroutine
    def start_night(self):
        self.look_own_card_done = Event()
        self.werewolves_wake_up_done = Event()
        yield [
            self.look_own_card(),
            self.werewolves_wake_up(),
            self.seer_wake_up(),
        ]

    @gen.coroutine
    def look_own_card(self):
        self.api('chat.postMessage', data={
            'channel': self.channel_id,
            'text': self.LOOK_OWN_CARD,
            'attachments': json.dumps([
                {
                    'text': None,
                    'callback_id': key('look_own_card', self.id),
                    'actions': [
                        {
                            'name': 'look_own_card',
                            'text': self.LOOK_OWN_CARD_ACTION,
                            'type': 'button',
                        },
                    ],
                },
            ]),
        })

    def on_look_own_card(self, user, response_url):
        role = dict(self.player_roles).get(user['id'])
        text = self.LOOK_OWN_CARD_REVEAL.format(self.ROLES_LABEL[role])
        requests.post(response_url, json={
            'text': text,
            'replace_original': False,
            'response_type': 'ephemeral',
        })
        player_ids = self.get_player_ids()
        look_own_key = key('look_own_players', self.id)
        self.redis.sadd(look_own_key, user['id'])

        # DEBUGGING: set the event immediately instead of waiting for
        # every player to look (the full check below is bypassed)
        self.look_own_card_done.set()

        def check_look_own(look_own_players):
            if not look_own_players:
                return

            # Check if all players have finished
            # looking at their own card
            for p in player_ids:
                if p not in look_own_players:
                    break
            else:
                self.look_own_card_done.set()

        self.redis.smembers(look_own_key, callback=check_look_own)

    @gen.coroutine
    def werewolves_wake_up(self):
        yield self.look_own_card_done.wait()
        # Allow werewolves to check fellow werewolves
        look_for_others_cb_id = key(
            'werewolf_look_for_others', self.id)
        look_at_center_cb_id = key(
            'werewolf_look_at_center', self.id)
        self.api('chat.postMessage', data={
            'channel': self.channel_id,
            'text': self.WEREWOLF_WAKE_UP,
            'attachments': json.dumps([
                {
                    'text': self.WEREWOLF_ATTACHMENT,
                    'callback_id': look_for_others_cb_id,
                    'actions': [
                        {
                            'name': 'werewolf_look_for_others',
                            'text': self.WEREWOLF_LOOK_FOR_OTHERS,
                            'type': 'button',
                        },
                    ],
                },
                {
                    'text': self.WEREWOLF_LONE_ATTACHMENT,
                    'callback_id': look_at_center_cb_id,
                    'actions': [
                        {
                            'name': 'center_1',
                            'value': 0,
                            'text': self.CENTER_1,
                            'type': 'button',
                        },
                        {
                            'name': 'center_2',
                            'value': 1,
                            'text': self.CENTER_2,
                            'type': 'button',
                        },
                        {
                            'name': 'center_3',
                            'value': 2,
                            'text': self.CENTER_3,
                            'type': 'button',
                        },
                    ],
                },
            ]),
        })
        player_werewolf_ids = self.get_player_werewolf_ids()
        if len(player_werewolf_ids) == 0:
            ioloop = IOLoop.current()
            ioloop.call_later(10, lambda: self.werewolves_wake_up_done.set())

    def on_werewolf_look_for_others(self, user, response_url):
        player_werewolf_ids = self.get_player_werewolf_ids()
        awake_key = key('awake_player_werewolves', self.id)

        # Check if user is an actual werewolf
        if user['id'] in player_werewolf_ids:
            if len(player_werewolf_ids) == 1:
                requests.post(response_url, json={
                    'text': self.WEREWOLF_LONE,
                    'replace_original': False,
                    'response_type': 'ephemeral',
                })
                return
            tags = map(
                lambda w: '<@{}>'.format(w),
                filter(lambda w: w != user['id'], player_werewolf_ids))
            text = self.WEREWOLF_LIST.format(', '.join(tags))
            self.redis.sadd(awake_key, user['id'])
        else:
            text = self.WEREWOLF_FALSE
        requests.post(response_url, json={
            'text': text,
            'replace_original': False,
            'response_type': 'ephemeral',
        })

        def check_awake(awake_player_werewolves):
            if not awake_player_werewolves:
                return
            # Check if all player werewolves have finished
            # checking on other werewolves
            for p in player_werewolf_ids:
                if p not in awake_player_werewolves:
                    break
            else:
                self.werewolves_wake_up_done.set()

        self.redis.smembers(awake_key, callback=check_awake)

    def on_werewolf_look_at_center(self, user, actions, response_url):
        player_werewolf_ids = self.get_player_werewolf_ids()
        lone_key = key('lone_wolf_looked', self.id)

        def check_looked(looked):
            if looked:
                requests.post(response_url, json={
                    'text': self.WEREWOLF_LONE_LOOKED,
                    'replace_original': False,
                    'response_type': 'ephemeral',
                })
                return
            action = actions[0]
            if user['id'] in player_werewolf_ids:
                if len(player_werewolf_ids) != 1:
                    requests.post(response_url, json={
                        'text': self.WEREWOLF_NOT_LONE,
                        'replace_original': False,
                        'response_type': 'ephemeral',
                    })
                    return
                chosen_center = int(action['value'])
                role = dict(self.player_roles).get(chosen_center)
                card_label = [self.CENTER_1,
                              self.CENTER_2,
                              self.CENTER_3][chosen_center]
                text = self.WEREWOLF_LOOK_AT_CENTER.format(
                    card_label, self.ROLES_LABEL[role])
                self.redis.set(lone_key, chosen_center)
                self.werewolves_wake_up_done.set()
            else:
                text = self.WEREWOLF_FALSE
            requests.post(response_url, json={
                'text': text,
                'replace_original': False,
                'response_type': 'ephemeral',
            })

        self.redis.exists(lone_key, callback=check_looked)

    @gen.coroutine
    def seer_wake_up(self):
        yield self.werewolves_wake_up_done.wait()
        self.api('chat.postMessage', data={
            'channel': self.channel_id,
            'text': self.SEER_WAKE_UP,
        })
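
The Game class above relies on a small key helper that is not shown in this listing. Judging from the line _, evt, _ = callback_id.split(':') in on_button, it must produce a three-part, colon-delimited string with the event name in the middle segment. A minimal sketch under that assumption (the 'onenight' prefix is hypothetical):

def key(evt, game_id, prefix='onenight'):
    # Build an identifier such as 'onenight:look_own_card:<game id>'.
    # on_button() later recovers the event name with callback_id.split(':'),
    # so the event name must occupy the middle segment.
    return '{}:{}:{}'.format(prefix, evt, game_id)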
Code example #34
class ZMQDrain(object):
    """Implementation of IDrain that pushes to a zmq.Socket asynchronously.
    This implementation overrides the high-water mark behavior from
    cs.eyrie.vassal.Vassal to instead use a zmq.Poller.
    """

    def __init__(self, logger, loop, zmq_socket,
                 metric_prefix='emitter'):
        self.emitter = zmq_socket
        self.logger = logger
        self.loop = loop
        self.metric_prefix = metric_prefix
        self.output_error = Event()
        self.state = RUNNING
        self._writable = Event()
        self.sender_tag = 'sender:%s.%s' % (self.__class__.__module__,
                                            self.__class__.__name__)

    def _handle_events(self, fd, events):
        if events & self.loop.ERROR:
            self.logger.error('Error polling socket for writability')
        elif events & self.loop.WRITE:
            self.loop.remove_handler(self.emitter)
            self._writable.set()

    @gen.coroutine
    def _poll(self):
        self.loop.add_handler(self.emitter,
                              self._handle_events,
                              self.loop.WRITE)
        yield self._writable.wait()
        self._writable.clear()

    @gen.coroutine
    def close(self, timeout=None):
        self.state = CLOSING
        self.logger.debug("Flushing send queue")
        self.emitter.close()

    def emit_nowait(self, msg):
        self.logger.debug("Drain emitting")
        if isinstance(msg, basestring):
            msg = [msg]
        try:
            self.emitter.send_multipart(msg, zmq.NOBLOCK)
        except zmq.Again:
            raise QueueFull()

    @gen.coroutine
    def emit(self, msg, retry_timeout=INITIAL_TIMEOUT):
        if isinstance(msg, basestring):
            msg = [msg]
        while True:
            # This should ensure the ZMQ socket can accept more data
            yield self._poll()
            try:
                self.emitter.send_multipart(msg, zmq.NOBLOCK)
            except zmq.Again:
                # But sometimes it's not enough
                self.logger.debug('Socket still not writable; retrying send')
                retry_timeout = min(retry_timeout*2, MAX_TIMEOUT)
                yield gen.sleep(retry_timeout.total_seconds())
            else:
                break
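
The docstring notes that ZMQDrain waits for writability through the IOLoop instead of relying on the socket's high-water mark. A minimal usage sketch, assuming pyzmq's Tornado integration and the RUNNING/INITIAL_TIMEOUT constants defined elsewhere in the same module (the endpoint is hypothetical):

import logging

import zmq
from zmq.eventloop import ioloop
ioloop.install()  # let Tornado's IOLoop poll ZMQ sockets (pyzmq < 17 pattern)

from tornado import gen
from tornado.ioloop import IOLoop

@gen.coroutine
def push_one_message():
    context = zmq.Context.instance()
    socket = context.socket(zmq.PUSH)
    socket.bind('tcp://127.0.0.1:5555')  # hypothetical endpoint
    drain = ZMQDrain(logging.getLogger('drain'), IOLoop.current(), socket)
    # emit() yields until the socket polls writable, then sends.
    yield drain.emit([b'hello'])
    yield drain.close()

IOLoop.current().run_sync(push_one_message)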
Code example #35
class PollingLock(object):
    """ Acquires a lock by writing to a key. This is suitable for a leader
      election in cases where some downtime and initial acquisition delay is
      acceptable.

      Unlike ZooKeeper and etcd, FoundationDB does not have a way
      to specify that a key should be automatically deleted if a client does
      not heartbeat at a regular interval. This implementation requires the
      leader to update the key at regular intervals to indicate that it is
      still alive. All the other lock candidates check at a longer interval to
      see if the leader has stopped updating the key.

      Since client timestamps are unreliable, candidates do not know the
      absolute time the key was updated. Therefore, they each wait for the full
      timeout interval before checking the key again.
  """
    # The number of seconds to wait before trying to claim the lease.
    _LEASE_TIMEOUT = 60

    # The number of seconds to wait before updating the lease.
    _HEARTBEAT_INTERVAL = int(_LEASE_TIMEOUT / 10)

    def __init__(self, db, tornado_fdb, key):
        self.key = key
        self._db = db
        self._tornado_fdb = tornado_fdb

        self._client_id = uuid.uuid4()
        self._owner = None
        self._op_id = None
        self._deadline = None
        self._event = Event()

    @property
    def acquired(self):
        if self._deadline is None:
            return False

        return (self._owner == self._client_id
                and monotonic.monotonic() < self._deadline)

    def start(self):
        IOLoop.current().spawn_callback(self._run)

    @gen.coroutine
    def acquire(self):
        # Since there is no automatic event timeout, the condition is checked
        # before every acquisition.
        if not self.acquired:
            self._event.clear()

        yield self._event.wait()

    @gen.coroutine
    def _run(self):
        while True:
            try:
                yield self._acquire_lease()
            except Exception:
                logger.exception(u'Unable to acquire lease')
                yield gen.sleep(random.random() * 20)

    @gen.coroutine
    def _acquire_lease(self):
        tr = self._db.create_transaction()
        lease_value = yield self._tornado_fdb.get(tr, self.key)

        if lease_value.present():
            self._owner, new_op_id = fdb.tuple.unpack(lease_value)
            if new_op_id != self._op_id:
                self._deadline = monotonic.monotonic() + self._LEASE_TIMEOUT
                self._op_id = new_op_id
        else:
            self._owner = None

        can_acquire = (self._owner is None
                       or monotonic.monotonic() > self._deadline)
        if can_acquire or self._owner == self._client_id:
            op_id = uuid.uuid4()
            tr[self.key] = fdb.tuple.pack((self._client_id, op_id))
            try:
                yield self._tornado_fdb.commit(tr, convert_exceptions=False)
            except fdb.FDBError as fdb_error:
                if fdb_error.code != FDBErrorCodes.NOT_COMMITTED:
                    raise

                # If there was a conflict, try to acquire again later.
                yield gen.sleep(random.random() * 20)
                return

            self._owner = self._client_id
            self._op_id = op_id
            self._deadline = monotonic.monotonic() + self._LEASE_TIMEOUT
            self._event.set()
            if can_acquire:
                logger.info(u'Acquired lock for {!r}'.format(self.key))

            yield gen.sleep(self._HEARTBEAT_INTERVAL)
            return

        # Since another candidate holds the lock, wait until it might expire.
        yield gen.sleep(max(self._deadline - monotonic.monotonic(), 0))
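
Given the protocol described in the docstring, a candidate starts the background loop once and then waits on acquire() before doing leader-only work. A minimal sketch, assuming an open FoundationDB handle and the TornadoFDB wrapper this class expects (the key and the work function are hypothetical):

from tornado import gen

@gen.coroutine
def leader_loop(db, tornado_fdb):
    lock = PollingLock(db, tornado_fdb, key=b'leader-lease')  # hypothetical key
    lock.start()              # spawns the heartbeat/claim coroutine
    while True:
        yield lock.acquire()  # resolves only while this client holds the lease
        do_leader_only_work() # hypothetical leader-only task
        yield gen.sleep(5)    # the acquired property is re-checked on each acquire()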
Code example #36
File: nanny.py Project: vas28r13/distributed
class WorkerProcess(object):
    def __init__(self, worker_args, worker_kwargs, worker_start_args,
                 silence_logs, on_exit, worker):
        self.status = 'init'
        self.silence_logs = silence_logs
        self.worker_args = worker_args
        self.worker_kwargs = worker_kwargs
        self.worker_start_args = worker_start_args
        self.on_exit = on_exit
        self.process = None
        self.Worker = worker

        # Initialized when worker is ready
        self.worker_dir = None
        self.worker_address = None

    @gen.coroutine
    def start(self):
        """
        Ensure the worker process is started.
        """
        enable_proctitle_on_children()
        if self.status == 'running':
            raise gen.Return(self.status)
        if self.status == 'starting':
            yield self.running.wait()
            raise gen.Return(self.status)

        self.init_result_q = init_q = mp_context.Queue()
        self.child_stop_q = mp_context.Queue()
        uid = uuid.uuid4().hex

        self.process = AsyncProcess(
            target=self._run,
            kwargs=dict(worker_args=self.worker_args,
                        worker_kwargs=self.worker_kwargs,
                        worker_start_args=self.worker_start_args,
                        silence_logs=self.silence_logs,
                        init_result_q=self.init_result_q,
                        child_stop_q=self.child_stop_q,
                        uid=uid,
                        Worker=self.Worker),
        )
        self.process.daemon = True
        self.process.set_exit_callback(self._on_exit)
        self.running = Event()
        self.stopped = Event()
        self.status = 'starting'
        yield self.process.start()
        msg = yield self._wait_until_connected(uid)
        if not msg:
            raise gen.Return(self.status)
        self.worker_address = msg['address']
        self.worker_dir = msg['dir']
        assert self.worker_address
        self.status = 'running'
        self.running.set()

        init_q.close()

        raise gen.Return(self.status)

    def _on_exit(self, proc):
        if proc is not self.process:
            # Ignore exit of old process instance
            return
        self.mark_stopped()

    def _death_message(self, pid, exitcode):
        assert exitcode is not None
        if exitcode == 255:
            return "Worker process %d was killed by unknown signal" % (pid, )
        elif exitcode >= 0:
            return "Worker process %d exited with status %d" % (
                pid,
                exitcode,
            )
        else:
            return "Worker process %d was killed by signal %d" % (
                pid,
                -exitcode,
            )

    def is_alive(self):
        return self.process is not None and self.process.is_alive()

    @property
    def pid(self):
        return (self.process.pid
                if self.process and self.process.is_alive() else None)

    def mark_stopped(self):
        if self.status != 'stopped':
            r = self.process.exitcode
            assert r is not None
            if r != 0:
                msg = self._death_message(self.process.pid, r)
                logger.warning(msg)
            self.status = 'stopped'
            self.stopped.set()
            # Release resources
            self.process.close()
            self.init_result_q = None
            self.child_stop_q = None
            self.process = None
            # Best effort to clean up worker directory
            if self.worker_dir and os.path.exists(self.worker_dir):
                shutil.rmtree(self.worker_dir, ignore_errors=True)
            self.worker_dir = None
            # User hook
            if self.on_exit is not None:
                self.on_exit(r)

    @gen.coroutine
    def kill(self, timeout=2, executor_wait=True):
        """
        Ensure the worker process is stopped, waiting at most
        *timeout* seconds before terminating it abruptly.
        """
        loop = IOLoop.current()
        deadline = loop.time() + timeout

        if self.status == 'stopped':
            return
        if self.status == 'stopping':
            yield self.stopped.wait()
            return
        assert self.status in ('starting', 'running')
        self.status = 'stopping'

        process = self.process
        self.child_stop_q.put({
            'op': 'stop',
            'timeout': max(0, deadline - loop.time()) * 0.8,
            'executor_wait': executor_wait,
        })
        self.child_stop_q.close()

        while process.is_alive() and loop.time() < deadline:
            yield gen.sleep(0.05)

        if process.is_alive():
            logger.warning(
                "Worker process still alive after %d seconds, killing",
                timeout)
            try:
                yield process.terminate()
            except Exception as e:
                logger.error("Failed to kill worker process: %s", e)

    @gen.coroutine
    def _wait_until_connected(self, uid):
        delay = 0.05
        while True:
            if self.status != 'starting':
                return
            try:
                msg = self.init_result_q.get_nowait()
            except Empty:
                yield gen.sleep(delay)
                continue

            if msg['uid'] != uid:  # ensure that we didn't cross queues
                continue

            if 'exception' in msg:
                logger.error("Failed while trying to start worker process: %s",
                             msg['exception'])
                yield self.process.join()
                raise msg['exception']
            else:
                raise gen.Return(msg)

    @classmethod
    def _run(cls, worker_args, worker_kwargs, worker_start_args, silence_logs,
             init_result_q, child_stop_q, uid, Worker):  # pragma: no cover
        try:
            from dask.multiprocessing import initialize_worker_process
        except ImportError:  # old Dask version
            pass
        else:
            initialize_worker_process()

        if silence_logs:
            logger.setLevel(silence_logs)

        IOLoop.clear_instance()
        loop = IOLoop()
        loop.make_current()
        worker = Worker(*worker_args, **worker_kwargs)

        @gen.coroutine
        def do_stop(timeout=5, executor_wait=True):
            try:
                yield worker._close(report=False,
                                    nanny=False,
                                    executor_wait=executor_wait,
                                    timeout=timeout)
            finally:
                loop.stop()

        def watch_stop_q():
            """
            Wait for an incoming stop message and then stop the
            worker cleanly.
            """
            while True:
                try:
                    msg = child_stop_q.get(timeout=1000)
                except Empty:
                    pass
                else:
                    child_stop_q.close()
                    assert msg.pop('op') == 'stop'
                    loop.add_callback(do_stop, **msg)
                    break

        t = threading.Thread(target=watch_stop_q,
                             name="Nanny stop queue watch")
        t.daemon = True
        t.start()

        @gen.coroutine
        def run():
            """
            Try to start worker and inform parent of outcome.
            """
            try:
                yield worker._start(*worker_start_args)
            except Exception as e:
                logger.exception("Failed to start worker")
                init_result_q.put({'uid': uid, 'exception': e})
                init_result_q.close()
            else:
                assert worker.address
                init_result_q.put({
                    'address': worker.address,
                    'dir': worker.local_dir,
                    'uid': uid
                })
                init_result_q.close()
                yield worker.wait_until_closed()
                logger.info("Worker closed")

        try:
            loop.run_sync(run)
        except TimeoutError:
            # Loop was stopped before wait_until_closed() returned, ignore
            pass
        except KeyboardInterrupt:
            pass
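
A minimal sketch of driving WorkerProcess from a coroutine, assuming a Worker class and scheduler address in the style of distributed (the address and keyword values are hypothetical):

import logging

from tornado import gen
from tornado.ioloop import IOLoop

@gen.coroutine
def manage_worker(Worker):
    proc = WorkerProcess(
        worker_args=('tcp://127.0.0.1:8786',),  # scheduler address (assumed)
        worker_kwargs={},
        worker_start_args=(),
        silence_logs=False,
        on_exit=lambda exitcode: logging.info('worker exited: %s', exitcode),
        worker=Worker,
    )
    status = yield proc.start()  # returns once the child reports its address
    logging.info('%s at %s', status, proc.worker_address)
    yield proc.kill(timeout=2)   # graceful stop request, then terminate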
Code example #37
File: executor.py Project: cowlicks/distributed
class Executor(object):
    """ Distributed executor with data dependencies

    This executor resembles executors in concurrent.futures but also allows
    Futures within submit/map calls.

    Provide center address on initialization

    >>> executor = Executor(('127.0.0.1', 8787))  # doctest: +SKIP

    Use ``submit`` method like normal

    >>> a = executor.submit(add, 1, 2)  # doctest: +SKIP
    >>> b = executor.submit(add, 10, 20)  # doctest: +SKIP

    Additionally, provide results of submit calls (futures) to further submit
    calls:

    >>> c = executor.submit(add, a, b)  # doctest: +SKIP

    This allows for the dynamic creation of complex dependencies.
    """
    def __init__(self, center, start=True, delete_batch_time=1):
        self.center = coerce_to_rpc(center)
        self.futures = dict()
        self.refcount = defaultdict(lambda: 0)
        self.dask = dict()
        self.restrictions = dict()
        self.loop = IOLoop()
        self.report_queue = Queue()
        self.scheduler_queue = Queue()
        self._shutdown_event = Event()
        self._delete_batch_time = delete_batch_time

        if start:
            self.start()

    def start(self):
        """ Start scheduler running in separate thread """
        from threading import Thread
        self.loop.add_callback(self._go)
        self._loop_thread = Thread(target=self.loop.start)
        self._loop_thread.start()

    def __enter__(self):
        if not self.loop._running:
            self.start()
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()

    def _inc_ref(self, key):
        self.refcount[key] += 1

    def _dec_ref(self, key):
        self.refcount[key] -= 1
        if self.refcount[key] == 0:
            del self.refcount[key]
            self._release_key(key)

    def _release_key(self, key):
        """ Release key from distributed memory """
        self.futures[key]['event'].clear()
        logger.debug("Release key %s", key)
        del self.futures[key]
        self.scheduler_queue.put_nowait({'op': 'release-held-data',
                                         'key': key})

    @gen.coroutine
    def report(self):
        """ Listen to scheduler """
        while True:
            msg = yield self.report_queue.get()
            if msg['op'] == 'close':
                break
            if msg['op'] == 'task-finished':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'finished'
                    self.futures[msg['key']]['event'].set()
            if msg['op'] == 'lost-data':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'lost'
                    self.futures[msg['key']]['event'].clear()
            if msg['op'] == 'task-erred':
                if msg['key'] in self.futures:
                    self.futures[msg['key']]['status'] = 'error'
                    self.futures[msg['key']]['event'].set()

    @gen.coroutine
    def _shutdown(self):
        """ Send shutdown signal and wait until _go completes """
        self.report_queue.put_nowait({'op': 'close'})
        self.scheduler_queue.put_nowait({'op': 'close'})
        yield self._shutdown_event.wait()

    def shutdown(self):
        """ Send shutdown signal and wait until scheduler terminates """
        self.report_queue.put_nowait({'op': 'close'})
        self.scheduler_queue.put_nowait({'op': 'close'})
        self.loop.stop()
        self._loop_thread.join()

    @gen.coroutine
    def _go(self):
        """ Setup and run all other coroutines.  Block until finished. """
        self.who_has, self.has_what, self.ncores = yield [
            self.center.who_has(),
            self.center.has_what(),
            self.center.ncores(),
        ]
        self.waiting = {}
        self.processing = {}
        self.stacks = {}

        worker_queues = {worker: Queue() for worker in self.ncores}
        delete_queue = Queue()

        coroutines = ([
            self.report(),
            scheduler(self.scheduler_queue, self.report_queue, worker_queues,
                      delete_queue, self.who_has, self.has_what, self.ncores,
                      self.dask, self.restrictions, self.waiting, self.stacks,
                      self.processing),
            delete(self.scheduler_queue, delete_queue,
                   self.center.ip, self.center.port, self._delete_batch_time)]
         + [worker(self.scheduler_queue, worker_queues[w], w, n)
            for w, n in self.ncores.items()])

        results = yield All(coroutines)
        self._shutdown_event.set()

    def submit(self, func, *args, **kwargs):
        """ Submit a function application to the scheduler

        Parameters
        ----------
        func: callable
        *args:
        **kwargs:
        pure: bool (defaults to True)
            Whether or not the function is pure.  Set ``pure=False`` for
            impure functions like ``np.random.random``.
        workers: set, iterable of sets
            A set of worker hostnames on which computations may be performed.
            Leave empty to default to all workers (common case)

        Examples
        --------
        >>> c = executor.submit(add, a, b)  # doctest: +SKIP

        Returns
        -------
        Future

        See Also
        --------
        distributed.executor.Executor.submit:
        """
        if not callable(func):
            raise TypeError("First input to submit must be a callable function")

        key = kwargs.pop('key', None)
        pure = kwargs.pop('pure', True)
        workers = kwargs.pop('workers', None)

        if key is None:
            if pure:
                key = funcname(func) + '-' + tokenize(func, kwargs, *args)
            else:
                key = funcname(func) + '-' + next(tokens)

        if key in self.futures:
            return Future(key, self)

        if kwargs:
            task = (apply, func, args, kwargs)
        else:
            task = (func,) + args

        if workers is not None:
            restrictions = {key: workers}
        else:
            restrictions = {}

        if key not in self.futures:
            self.futures[key] = {'event': Event(), 'status': 'waiting'}

        logger.debug("Submit %s(...), %s", funcname(func), key)
        self.scheduler_queue.put_nowait({'op': 'update-graph',
                                         'dsk': {key: task},
                                         'keys': [key],
                                         'restrictions': restrictions})

        return Future(key, self)

    def map(self, func, *iterables, **kwargs):
        """ Map a function on a sequence of arguments

        Arguments can be normal objects or Futures

        Parameters
        ----------
        func: callable
        iterables: Iterables
        pure: bool (defaults to True)
            Whether or not the function is pure.  Set ``pure=False`` for
            impure functions like ``np.random.random``.
        workers: set, iterable of sets
            A set of worker hostnames on which computations may be performed.
            Leave empty to default to all workers (common case)

        Examples
        --------
        >>> L = executor.map(func, sequence)  # doctest: +SKIP

        Returns
        -------
        list of futures

        See also
        --------
        distributed.executor.Executor.submit
        """
        pure = kwargs.pop('pure', True)
        workers = kwargs.pop('workers', None)
        if not callable(func):
            raise TypeError("First input to map must be a callable function")
        iterables = [list(it) for it in iterables]
        if pure:
            keys = [funcname(func) + '-' + tokenize(func, kwargs, *args)
                    for args in zip(*iterables)]
        else:
            uid = str(uuid.uuid4())
            keys = [funcname(func) + '-' + uid + '-' + next(tokens)
                    for i in range(min(map(len, iterables)))]

        if not kwargs:
            dsk = {key: (func,) + args
                   for key, args in zip(keys, zip(*iterables))}
        else:
            dsk = {key: (apply, func, args, kwargs)
                   for key, args in zip(keys, zip(*iterables))}

        for key in dsk:
            if key not in self.futures:
                self.futures[key] = {'event': Event(), 'status': 'waiting'}

        if isinstance(workers, (list, set)):
            if workers and isinstance(first(workers), (list, set)):
                if len(workers) != len(keys):
                    raise ValueError("You only provided %d worker restrictions"
                    " for a sequence of length %d" % (len(workers), len(keys)))
                restrictions = dict(zip(keys, workers))
            else:
                restrictions = {key: workers for key in keys}
        elif workers is None:
            restrictions = {}
        else:
            raise TypeError("Workers must be a list or set of workers or None")

        logger.debug("map(%s, ...)", funcname(func))
        self.scheduler_queue.put_nowait({'op': 'update-graph',
                                         'dsk': dsk,
                                         'keys': keys,
                                         'restrictions': restrictions})

        return [Future(key, self) for key in keys]

    @gen.coroutine
    def _gather(self, futures):
        futures2, keys = unpack_remotedata(futures)
        keys = list(keys)

        while True:
            yield All([self.futures[key]['event'].wait() for key in keys])
            try:
                data = yield _gather(self.center, keys)
            except KeyError as e:
                self.scheduler_queue.put_nowait({'op': 'missing-data',
                                                 'missing': e.args})
                for key in e.args:
                    self.futures[key]['event'].clear()
            else:
                break

        data = dict(zip(keys, data))

        result = pack_data(futures2, data)
        raise gen.Return(result)

    def gather(self, futures):
        """ Gather futures from distributed memory

        Accepts a future or any nested core container of futures

        Examples
        --------
        >>> from operator import add  # doctest: +SKIP
        >>> e = Executor('127.0.0.1:8787')  # doctest: +SKIP
        >>> x = e.submit(add, 1, 2)  # doctest: +SKIP
        >>> e.gather(x)  # doctest: +SKIP
        3
        >>> e.gather([x, [x], x])  # doctest: +SKIP
        [3, [3], 3]
        """
        return sync(self.loop, self._gather, futures)

    @gen.coroutine
    def _get(self, dsk, keys, restrictions=None):
        flatkeys = list(flatten(keys))
        for key in flatkeys:
            if key not in self.futures:
                self.futures[key] = {'event': Event(), 'status': None}
        futures = {key: Future(key, self) for key in flatkeys}

        self.scheduler_queue.put_nowait({'op': 'update-graph',
                                         'dsk': dsk,
                                         'keys': flatkeys,
                                         'restrictions': restrictions or {}})

        packed = pack_data(keys, futures)
        result = yield self._gather(packed)
        raise gen.Return(result)

    def get(self, dsk, keys, **kwargs):
        """ Gather futures from distributed memory

        Parameters
        ----------
        dsk: dict
        keys: object, or nested lists of objects
        restrictions: dict (optional)
            A mapping of {key: {set of worker hostnames}} that restricts where
            jobs can take place

        Examples
        --------
        >>> from operator import add  # doctest: +SKIP
        >>> e = Executor('127.0.0.1:8787')  # doctest: +SKIP
        >>> e.get({'x': (add, 1, 2)}, 'x')  # doctest: +SKIP
        3
        """
        return sync(self.loop, self._get, dsk, keys, **kwargs)
Code example #38
class TornadoSubscriptionManager(SubscriptionManager):
    def __init__(self, pubnub_instance):

        subscription_manager = self

        self._message_queue = Queue()
        self._consumer_event = Event()
        self._cancellation_event = Event()
        self._subscription_lock = Semaphore(1)
        # self._current_request_key_object = None
        self._heartbeat_periodic_callback = None
        self._reconnection_manager = TornadoReconnectionManager(pubnub_instance)

        super(TornadoSubscriptionManager, self).__init__(pubnub_instance)
        self._start_worker()

        class TornadoReconnectionCallback(ReconnectionCallback):
            def on_reconnect(self):
                subscription_manager.reconnect()

                pn_status = PNStatus()
                pn_status.category = PNStatusCategory.PNReconnectedCategory
                pn_status.error = False

                subscription_manager._subscription_status_announced = True
                subscription_manager._listener_manager.announce_status(pn_status)

        self._reconnection_listener = TornadoReconnectionCallback()
        self._reconnection_manager.set_reconnection_listener(self._reconnection_listener)

    def _set_consumer_event(self):
        self._consumer_event.set()

    def _message_queue_put(self, message):
        self._message_queue.put(message)

    def _start_worker(self):
        self._consumer = TornadoSubscribeMessageWorker(self._pubnub,
                                                       self._listener_manager,
                                                       self._message_queue,
                                                       self._consumer_event)
        run = stack_context.wrap(self._consumer.run)
        self._pubnub.ioloop.spawn_callback(run)

    def reconnect(self):
        self._should_stop = False
        self._pubnub.ioloop.spawn_callback(self._start_subscribe_loop)
        # self._register_heartbeat_timer()

    def disconnect(self):
        self._should_stop = True
        self._stop_heartbeat_timer()
        self._stop_subscribe_loop()

    @tornado.gen.coroutine
    def _start_subscribe_loop(self):
        self._stop_subscribe_loop()

        yield self._subscription_lock.acquire()

        self._cancellation_event.clear()

        combined_channels = self._subscription_state.prepare_channel_list(True)
        combined_groups = self._subscription_state.prepare_channel_group_list(True)

        if len(combined_channels) == 0 and len(combined_groups) == 0:
            return

        envelope_future = Subscribe(self._pubnub) \
            .channels(combined_channels).channel_groups(combined_groups) \
            .timetoken(self._timetoken).region(self._region) \
            .filter_expression(self._pubnub.config.filter_expression) \
            .cancellation_event(self._cancellation_event) \
            .future()

        canceller_future = self._cancellation_event.wait()

        wi = tornado.gen.WaitIterator(envelope_future, canceller_future)

        # iterates twice: once for the result, once for cancellation
        while not wi.done():
            try:
                result = yield wi.next()
            except Exception as e:
                # TODO: verify the error will not be eaten
                logger.error(e)
                raise
            else:
                if wi.current_future == envelope_future:
                    e = result
                elif wi.current_future == canceller_future:
                    return
                else:
                    raise Exception("Unexpected future resolved: %s" % str(wi.current_future))

                if e.is_error():
                    # a 599 status doesn't help here - tornado uses this
                    # status code for a wide range of errors, for example:
                    # HTTP Server Error (599): [Errno -2] Name or service not known
                    if e.status is not None and e.status.category == PNStatusCategory.PNTimeoutCategory:
                        self._pubnub.ioloop.spawn_callback(self._start_subscribe_loop)
                        return

                    logger.error("Exception in subscribe loop: %s" % str(e))

                    if e.status is not None and e.status.category == PNStatusCategory.PNAccessDeniedCategory:
                        e.status.operation = PNOperationType.PNUnsubscribeOperation

                    self._listener_manager.announce_status(e.status)

                    self._reconnection_manager.start_polling()
                    self.disconnect()
                    return
                else:
                    self._handle_endpoint_call(e.result, e.status)

                    self._pubnub.ioloop.spawn_callback(self._start_subscribe_loop)

            finally:
                self._cancellation_event.set()
                yield tornado.gen.moment
                self._subscription_lock.release()
                self._cancellation_event.clear()
                break

    def _stop_subscribe_loop(self):
        if self._cancellation_event is not None and not self._cancellation_event.is_set():
            self._cancellation_event.set()

    def _stop_heartbeat_timer(self):
        if self._heartbeat_periodic_callback is not None:
            self._heartbeat_periodic_callback.stop()

    def _register_heartbeat_timer(self):
        super(TornadoSubscriptionManager, self)._register_heartbeat_timer()
        self._heartbeat_periodic_callback = PeriodicCallback(
            stack_context.wrap(self._perform_heartbeat_loop),
            self._pubnub.config.heartbeat_interval * TornadoSubscriptionManager.HEARTBEAT_INTERVAL_MULTIPLIER,
            self._pubnub.ioloop)
        self._heartbeat_periodic_callback.start()

    @tornado.gen.coroutine
    def _perform_heartbeat_loop(self):
        if self._heartbeat_call is not None:
            # TODO: cancel call
            pass

        cancellation_event = Event()
        state_payload = self._subscription_state.state_payload()
        presence_channels = self._subscription_state.prepare_channel_list(False)
        presence_groups = self._subscription_state.prepare_channel_group_list(False)

        if len(presence_channels) == 0 and len(presence_groups) == 0:
            return

        try:
            envelope = yield self._pubnub.heartbeat() \
                .channels(presence_channels) \
                .channel_groups(presence_groups) \
                .state(state_payload) \
                .cancellation_event(cancellation_event) \
                .future()

            heartbeat_verbosity = self._pubnub.config.heartbeat_notification_options
            if envelope.status.is_error:
                if heartbeat_verbosity == PNHeartbeatNotificationOptions.ALL or \
                        heartbeat_verbosity == PNHeartbeatNotificationOptions.FAILURES:
                    self._listener_manager.announce_status(envelope.status)
            else:
                if heartbeat_verbosity == PNHeartbeatNotificationOptions.ALL:
                    self._listener_manager.announce_status(envelope.status)

        except PubNubTornadoException:
            pass
            # TODO: check correctness
            # if e.status is not None and e.status.category == PNStatusCategory.PNTimeoutCategory:
            #     self._start_subscribe_loop()
            # else:
            #     self._listener_manager.announce_status(e.status)
        except Exception as e:
            print(e)
        finally:
            cancellation_event.set()

    @tornado.gen.coroutine
    def _send_leave(self, unsubscribe_operation):
        envelope = yield Leave(self._pubnub) \
            .channels(unsubscribe_operation.channels) \
            .channel_groups(unsubscribe_operation.channel_groups).future()
        self._listener_manager.announce_status(envelope.status)
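
The subscribe loop above races the request future against a cancellation event with tornado.gen.WaitIterator. The same pattern in isolation, as a self-contained sketch (not PubNub-specific):

from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Event

@gen.coroutine
def slow_work():
    yield gen.sleep(10)
    raise gen.Return('finished')

@gen.coroutine
def race(work_future, cancellation_event):
    canceller_future = cancellation_event.wait()
    wi = gen.WaitIterator(work_future, canceller_future)
    while not wi.done():
        result = yield wi.next()
        if wi.current_future is work_future:
            raise gen.Return(result)  # the work finished first
        raise gen.Return(None)        # cancellation won the race

@gen.coroutine
def main():
    cancel = Event()
    IOLoop.current().call_later(0.1, cancel.set)
    outcome = yield race(slow_work(), cancel)
    print('outcome: %s' % outcome)    # None: cancelled before completion

IOLoop.current().run_sync(main)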
Code example #39
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application([
            ("/", HelloHandler),
            ("/large", LargeHandler),
            (
                "/finish_on_close",
                FinishOnCloseHandler,
                dict(cleanup_event=self.cleanup_event),
            ),
        ])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(
            self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"POST / HTTP/1.0\r\n"
                          b"Connection: keep-alive\r\n"
                          b"Transfer-Encoding: chunked\r\n"
                          b"\r\n"
                          b"0\r\n"
                          b"\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
Code example #40
class WebUpdater:
    def __init__(self, config, cmd_helper):
        self.server = cmd_helper.get_server()
        self.cmd_helper = cmd_helper
        self.repo = config.get('repo').strip().strip("/")
        self.owner, self.name = self.repo.split("/", 1)
        if hasattr(config, "get_name"):
            self.name = config.get_name().split()[-1]
        self.path = os.path.realpath(os.path.expanduser(config.get("path")))
        self.persistent_files = []
        pfiles = config.get('persistent_files', None)
        if pfiles is not None:
            self.persistent_files = [
                pf.strip().strip("/") for pf in pfiles.split("\n")
                if pf.strip()
            ]
            if ".version" in self.persistent_files:
                raise config.error(
                    "Invalid value for option 'persistent_files': "
                    "'.version' can not be persistent")

        self.version = self.remote_version = self.dl_url = "?"
        self.etag = None
        self.init_evt = Event()
        self.refresh_condition = None
        self._get_local_version()
        logging.info(f"\nInitializing Client Updater: '{self.name}',"
                     f"\nversion: {self.version}"
                     f"\npath: {self.path}")

    def _get_local_version(self):
        version_path = os.path.join(self.path, ".version")
        if os.path.isfile(version_path):
            with open(version_path, "r") as f:
                v = f.read()
            self.version = v.strip()

    async def check_initialized(self, timeout=None):
        if self.init_evt.is_set():
            return
        if timeout is not None:
            timeout = IOLoop.current().time() + timeout
        await self.init_evt.wait(timeout)

    async def refresh(self):
        if self.refresh_condition is None:
            self.refresh_condition = Condition()
        else:
            await self.refresh_condition.wait()
            return
        try:
            self._get_local_version()
            await self._get_remote_version()
        except Exception:
            logging.exception("Error Refreshing Client")
        self.init_evt.set()
        self.refresh_condition.notify_all()
        self.refresh_condition = None

    async def _get_remote_version(self):
        # Remote state
        url = f"https://api.github.com/repos/{self.repo}/releases/latest"
        try:
            result = await self.cmd_helper.github_api_request(url,
                                                              etag=self.etag)
        except Exception:
            logging.exception(f"Client {self.repo}: Github Request Error")
            result = {}
        if result is None:
            # No change, update not necessary
            return
        self.etag = result.get('etag', None)
        self.remote_version = result.get('name', "?")
        release_assets = result.get('assets', [{}])[0]
        self.dl_url = release_assets.get('browser_download_url', "?")
        logging.info(f"Github client Info Received:\nRepo: {self.name}\n"
                     f"Local Version: {self.version}\n"
                     f"Remote Version: {self.remote_version}\n"
                     f"url: {self.dl_url}")

    async def update(self, *args):
        await self.check_initialized(20.)
        if self.refresh_condition is not None:
            # wait for refresh if in progress
            await self.refresh_condition.wait()
        if self.remote_version == "?":
            await self.refresh()
            if self.remote_version == "?":
                raise self.server.error(
                    f"Client {self.repo}: Unable to locate update")
        if self.dl_url == "?":
            raise self.server.error(
                f"Client {self.repo}: Invalid download url")
        if self.version == self.remote_version:
            # Already up to date
            return
        self.cmd_helper.notify_update_response(
            f"Downloading Client: {self.name}")
        archive = await self.cmd_helper.http_download_request(self.dl_url)
        with tempfile.TemporaryDirectory(suffix=self.name,
                                         prefix="client") as tempdir:
            if os.path.isdir(self.path):
                # find and move persistent files
                for fname in os.listdir(self.path):
                    src_path = os.path.join(self.path, fname)
                    if fname in self.persistent_files:
                        dest_dir = os.path.dirname(os.path.join(
                            tempdir, fname))
                        os.makedirs(dest_dir, exist_ok=True)
                        shutil.move(src_path, dest_dir)
                shutil.rmtree(self.path)
            os.mkdir(self.path)
            with zipfile.ZipFile(io.BytesIO(archive)) as zf:
                zf.extractall(self.path)
            # Move persistent files back into the new client directory
            for fname in os.listdir(tempdir):
                src_path = os.path.join(tempdir, fname)
                dest_dir = os.path.dirname(os.path.join(self.path, fname))
                os.makedirs(dest_dir, exist_ok=True)
                shutil.move(src_path, dest_dir)
        self.version = self.remote_version
        version_path = os.path.join(self.path, ".version")
        if not os.path.exists(version_path):
            with open(version_path, "w") as f:
                f.write(self.version)
        self.cmd_helper.notify_update_response(
            f"Client Update Finished: {self.name}", is_complete=True)

    def get_update_status(self):
        return {
            'name': self.name,
            'owner': self.owner,
            'version': self.version,
            'remote_version': self.remote_version
        }
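
The check_initialized helper above converts a relative timeout in seconds into an
absolute deadline because tornado's Event.wait interprets a bare number as a time
on the IOLoop clock. A minimal sketch of the same conversion (the helper name is
illustrative, not part of the class):

from tornado.ioloop import IOLoop
from tornado.locks import Event

async def wait_up_to(evt: Event, seconds: float):
    # Event.wait() treats a plain number as an absolute IOLoop time, so
    # convert the relative value first.  It raises tornado.util.TimeoutError
    # if the event is not set in time; a datetime.timedelta would express
    # the same relative timeout without the conversion.
    deadline = IOLoop.current().time() + seconds
    await evt.wait(deadline)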
Code example #41
0
File: queues.py Project: bdarnell/tornado
class Queue(Generic[_T]):
    """Coordinate producer and consumer coroutines.

    If maxsize is 0 (the default) the queue size is unbounded.

    .. testcode::

        from tornado import gen
        from tornado.ioloop import IOLoop
        from tornado.queues import Queue

        q = Queue(maxsize=2)

        async def consumer():
            async for item in q:
                try:
                    print('Doing work on %s' % item)
                    await gen.sleep(0.01)
                finally:
                    q.task_done()

        async def producer():
            for item in range(5):
                await q.put(item)
                print('Put %s' % item)

        async def main():
            # Start consumer without waiting (since it never finishes).
            IOLoop.current().spawn_callback(consumer)
            await producer()     # Wait for producer to put all tasks.
            await q.join()       # Wait for consumer to finish all tasks.
            print('Done')

        IOLoop.current().run_sync(main)

    .. testoutput::

        Put 0
        Put 1
        Doing work on 0
        Put 2
        Doing work on 1
        Put 3
        Doing work on 2
        Put 4
        Doing work on 3
        Doing work on 4
        Done


    In versions of Python without native coroutines (before 3.5),
    ``consumer()`` could be written as::

        @gen.coroutine
        def consumer():
            while True:
                item = yield q.get()
                try:
                    print('Doing work on %s' % item)
                    yield gen.sleep(0.01)
                finally:
                    q.task_done()

    .. versionchanged:: 4.3
       Added ``async for`` support in Python 3.5.

    """

    # Exact type depends on subclass. Could be another generic
    # parameter and use protocols to be more precise here.
    _queue = None  # type: Any

    def __init__(self, maxsize: int = 0) -> None:
        if maxsize is None:
            raise TypeError("maxsize can't be None")

        if maxsize < 0:
            raise ValueError("maxsize can't be negative")

        self._maxsize = maxsize
        self._init()
        self._getters = collections.deque([])  # type: Deque[Future[_T]]
        self._putters = collections.deque([])  # type: Deque[Tuple[_T, Future[None]]]
        self._unfinished_tasks = 0
        self._finished = Event()
        self._finished.set()

    @property
    def maxsize(self) -> int:
        """Number of items allowed in the queue."""
        return self._maxsize

    def qsize(self) -> int:
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self) -> bool:
        return not self._queue

    def full(self) -> bool:
        if self.maxsize == 0:
            return False
        else:
            return self.qsize() >= self.maxsize

    def put(
        self, item: _T, timeout: Optional[Union[float, datetime.timedelta]] = None
    ) -> "Future[None]":
        """Put an item into the queue, perhaps waiting until there is room.

        Returns a Future, which raises `tornado.util.TimeoutError` after a
        timeout.

        ``timeout`` may be a number denoting a time (on the same
        scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
        `datetime.timedelta` object for a deadline relative to the
        current time.
        """
        future = Future()  # type: Future[None]
        try:
            self.put_nowait(item)
        except QueueFull:
            self._putters.append((item, future))
            _set_timeout(future, timeout)
        else:
            future.set_result(None)
        return future

    def put_nowait(self, item: _T) -> None:
        """Put an item into the queue without blocking.

        If no free slot is immediately available, raise `QueueFull`.
        """
        self._consume_expired()
        if self._getters:
            assert self.empty(), "queue non-empty, why are getters waiting?"
            getter = self._getters.popleft()
            self.__put_internal(item)
            future_set_result_unless_cancelled(getter, self._get())
        elif self.full():
            raise QueueFull
        else:
            self.__put_internal(item)

    def get(
        self, timeout: Optional[Union[float, datetime.timedelta]] = None
    ) -> Awaitable[_T]:
        """Remove and return an item from the queue.

        Returns an awaitable which resolves once an item is available, or raises
        `tornado.util.TimeoutError` after a timeout.

        ``timeout`` may be a number denoting a time (on the same
        scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
        `datetime.timedelta` object for a deadline relative to the
        current time.

        .. note::

           The ``timeout`` argument of this method differs from that
           of the standard library's `queue.Queue.get`. That method
           interprets numeric values as relative timeouts; this one
           interprets them as absolute deadlines and requires
           ``timedelta`` objects for relative timeouts (consistent
           with other timeouts in Tornado).

        """
        future = Future()  # type: Future[_T]
        try:
            future.set_result(self.get_nowait())
        except QueueEmpty:
            self._getters.append(future)
            _set_timeout(future, timeout)
        return future

    def get_nowait(self) -> _T:
        """Remove and return an item from the queue without blocking.

        Return an item if one is immediately available, else raise
        `QueueEmpty`.
        """
        self._consume_expired()
        if self._putters:
            assert self.full(), "queue not full, why are putters waiting?"
            item, putter = self._putters.popleft()
            self.__put_internal(item)
            future_set_result_unless_cancelled(putter, None)
            return self._get()
        elif self.qsize():
            return self._get()
        else:
            raise QueueEmpty

    def task_done(self) -> None:
        """Indicate that a formerly enqueued task is complete.

        Used by queue consumers. For each `.get` used to fetch a task, a
        subsequent call to `.task_done` tells the queue that the processing
        on the task is complete.

        If a `.join` is blocking, it resumes when all items have been
        processed; that is, when every `.put` is matched by a `.task_done`.

        Raises `ValueError` if called more times than `.put`.
        """
        if self._unfinished_tasks <= 0:
            raise ValueError("task_done() called too many times")
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    def join(
        self, timeout: Optional[Union[float, datetime.timedelta]] = None
    ) -> Awaitable[None]:
        """Block until all items in the queue are processed.

        Returns an awaitable, which raises `tornado.util.TimeoutError` after a
        timeout.
        """
        return self._finished.wait(timeout)

    def __aiter__(self) -> _QueueIterator[_T]:
        return _QueueIterator(self)

    # These three are overridable in subclasses.
    def _init(self) -> None:
        self._queue = collections.deque()

    def _get(self) -> _T:
        return self._queue.popleft()

    def _put(self, item: _T) -> None:
        self._queue.append(item)

    # End of the overridable methods.

    def __put_internal(self, item: _T) -> None:
        self._unfinished_tasks += 1
        self._finished.clear()
        self._put(item)

    def _consume_expired(self) -> None:
        # Remove timed-out waiters.
        while self._putters and self._putters[0][1].done():
            self._putters.popleft()

        while self._getters and self._getters[0].done():
            self._getters.popleft()

    def __repr__(self) -> str:
        return "<%s at %s %s>" % (type(self).__name__, hex(id(self)), self._format())

    def __str__(self) -> str:
        return "<%s %s>" % (type(self).__name__, self._format())

    def _format(self) -> str:
        result = "maxsize=%r" % (self.maxsize,)
        if getattr(self, "_queue", None):
            result += " queue=%r" % self._queue
        if self._getters:
            result += " getters[%s]" % len(self._getters)
        if self._putters:
            result += " putters[%s]" % len(self._putters)
        if self._unfinished_tasks:
            result += " tasks=%s" % self._unfinished_tasks
        return result
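
As the notes on put and get above explain, a bare number is treated as an absolute
deadline on the IOLoop clock, while relative timeouts use datetime.timedelta. A
minimal sketch showing both forms time out on an empty queue:

import datetime

from tornado.ioloop import IOLoop
from tornado.queues import Queue
from tornado.util import TimeoutError

async def main():
    q = Queue()
    try:
        # Relative timeout: measured from now.
        await q.get(timeout=datetime.timedelta(seconds=0.1))
    except TimeoutError:
        print('timedelta timeout expired')
    try:
        # Absolute deadline: a number on the IOLoop.time() scale.
        await q.get(timeout=IOLoop.current().time() + 0.1)
    except TimeoutError:
        print('absolute deadline passed')

IOLoop.current().run_sync(main)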
Code example #42
0
class GitUpdater:
    def __init__(self, config, cmd_helper, path=None, env=None):
        self.server = cmd_helper.get_server()
        self.cmd_helper = cmd_helper
        self.name = config.get_name().split()[-1]
        if path is None:
            path = os.path.expanduser(config.get('path'))
        self.repo_path = path
        self.repo = GitRepo(cmd_helper, path, self.name)
        self.init_evt = Event()
        self.debug = self.cmd_helper.is_debug_enabled()
        self.env = config.get("env", env)
        dist_packages = None
        self.python_reqs = None
        if self.env is not None:
            self.env = os.path.expanduser(self.env)
            dist_packages = config.get('python_dist_packages', None)
            self.python_reqs = os.path.join(self.repo_path,
                                            config.get("requirements"))
        self.origin = config.get("origin").lower()
        self.install_script = config.get('install_script', None)
        if self.install_script is not None:
            self.install_script = os.path.abspath(
                os.path.join(self.repo_path, self.install_script))
        self.venv_args = config.get('venv_args', None)
        self.python_dist_packages = None
        self.python_dist_path = None
        self.env_package_path = None
        if dist_packages is not None:
            self.python_dist_packages = [
                p.strip() for p in dist_packages.split('\n') if p.strip()
            ]
            self.python_dist_path = os.path.abspath(
                config.get('python_dist_path'))
            env_package_path = os.path.abspath(
                os.path.join(os.path.dirname(self.env), "..",
                             config.get('env_package_path')))
            matches = glob.glob(env_package_path)
            if len(matches) == 1:
                self.env_package_path = matches[0]
            else:
                raise config.error("No match for 'env_package_path': %s" %
                                   (env_package_path, ))
        for opt in [
                "repo_path", "env", "python_reqs", "install_script",
                "python_dist_path", "env_package_path"
        ]:
            val = getattr(self, opt)
            if val is None:
                continue
            if not os.path.exists(val):
                raise config.error("Invalid path for option '%s': %s" %
                                   (val, opt))

    def _get_version_info(self):
        ver_path = os.path.join(self.repo_path, "scripts/version.txt")
        vinfo = {}
        if os.path.isfile(ver_path):
            data = ""
            with open(ver_path, 'r') as f:
                data = f.read()
            try:
                entries = [e.strip() for e in data.split('\n') if e.strip()]
                vinfo = dict([i.split('=') for i in entries])
                vinfo = {
                    k: tuple(re.findall(r"\d+", v))
                    for k, v in vinfo.items()
                }
            except Exception:
                pass
            else:
                self._log_info(f"Version Info Found: {vinfo}")
        vinfo['version'] = self.repo.get_version()
        return vinfo

    def _log_exc(self, msg, traceback=True):
        log_msg = f"Repo {self.name}: {msg}"
        if traceback:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return self.server.error(msg)

    def _log_info(self, msg):
        log_msg = f"Repo {self.name}: {msg}"
        logging.info(log_msg)

    def _notify_status(self, msg, is_complete=False):
        log_msg = f"Repo {self.name}: {msg}"
        logging.debug(log_msg)
        self.cmd_helper.notify_update_response(log_msg, is_complete)

    async def check_initialized(self, timeout=None):
        if self.init_evt.is_set():
            return
        if timeout is not None:
            timeout = IOLoop.current().time() + timeout
        await self.init_evt.wait(timeout)

    async def refresh(self):
        try:
            await self._update_repo_state()
        except Exception:
            logging.exception("Error Refreshing git state")
        self.init_evt.set()

    async def _update_repo_state(self, need_fetch=True):
        self.is_valid = False
        await self.repo.initialize(need_fetch=need_fetch)
        invalids = self.repo.report_invalids(self.origin)
        if invalids:
            msgs = '\n'.join(invalids)
            self._log_info(f"Repo validation checks failed:\n{msgs}")
            if self.debug:
                self.is_valid = True
                self._log_info(
                    "Repo debug enabled, overriding validity checks")
            else:
                self._log_info("Updates on repo disabled")
        else:
            self.is_valid = True
            self._log_info("Validity check for git repo passed")

    async def update(self, update_deps=False):
        await self.check_initialized(20.)
        await self.repo.wait_for_init()
        if not self.is_valid:
            raise self._log_exc("Update aborted, repo not valid", False)
        if self.repo.is_dirty():
            raise self._log_exc("Update aborted, repo has been modified",
                                False)
        if self.repo.is_current():
            # No need to update
            return
        self._notify_status("Updating Repo...")
        try:
            if self.repo.is_detached():
                await self.repo.fetch()
                await self.repo.checkout()
            else:
                await self.repo.pull()
            # Prune stale references.  Do this separately from pull or
            # fetch to prevent a timeout during a prune
            await self.repo.prune()
        except Exception:
            raise self._log_exc("Error running 'git pull'")
        # Check Semantic Versions
        vinfo = self._get_version_info()
        cur_version = vinfo.get('version', ())
        update_deps |= cur_version < vinfo.get('deps_version', ())
        need_env_rebuild = cur_version < vinfo.get('env_version', ())
        if update_deps:
            await self._install_packages()
            await self._update_virtualenv(need_env_rebuild)
        elif need_env_rebuild:
            await self._update_virtualenv(True)
        # Refresh local repo state
        await self._update_repo_state(need_fetch=False)
        if self.name == "moonraker":
            # Launch restart async so the request can return
            # before the server restarts
            self._notify_status("Update Finished...", is_complete=True)
            IOLoop.current().call_later(.1, self.restart_service)
        else:
            await self.restart_service()
            self._notify_status("Update Finished...", is_complete=True)

    async def _install_packages(self):
        if self.install_script is None:
            return
        # Open the install script and read its contents
        inst_path = self.install_script
        if not os.path.isfile(inst_path):
            self._log_info(f"Unable to open install script: {inst_path}")
            return
        with open(inst_path, 'r') as f:
            data = f.read()
        packages = re.findall(r'PKGLIST="(.*)"', data)
        # lstrip() takes a character set here; this strips the leading
        # "${PKGLIST}" token from each captured value.
        packages = [p.lstrip("${PKGLIST}").strip() for p in packages]
        if not packages:
            self._log_info(f"No packages found in script: {inst_path}")
            return
        # TODO: Log and notify that packages will be installed
        pkgs = " ".join(packages)
        logging.debug(f"Repo {self.name}: Detected Packages: {pkgs}")
        self._notify_status("Installing system dependencies...")
        # Install packages with apt-get
        try:
            await self.cmd_helper.run_cmd(f"{APT_CMD} update",
                                          timeout=300.,
                                          notify=True)
            await self.cmd_helper.run_cmd(f"{APT_CMD} install --yes {pkgs}",
                                          timeout=3600.,
                                          notify=True)
        except Exception:
            self._log_exc("Error updating packages via apt-get")
            return

    async def _update_virtualenv(self, rebuild_env=False):
        if self.env is None:
            return
        # Update python dependencies
        bin_dir = os.path.dirname(self.env)
        env_path = os.path.normpath(os.path.join(bin_dir, ".."))
        if rebuild_env:
            self._notify_status(f"Creating virtualenv at: {env_path}...")
            if os.path.exists(env_path):
                shutil.rmtree(env_path)
            try:
                await self.cmd_helper.run_cmd(
                    f"virtualenv {self.venv_args} {env_path}", timeout=300.)
            except Exception:
                self._log_exc(f"Error creating virtualenv")
                return
            if not os.path.exists(self.env):
                raise self._log_exc("Failed to create new virtualenv", False)
        reqs = self.python_reqs
        if not os.path.isfile(reqs):
            self._log_exc(f"Invalid path to requirements_file '{reqs}'")
            return
        pip = os.path.join(bin_dir, "pip")
        self._notify_status("Updating python packages...")
        try:
            await self.cmd_helper.run_cmd(f"{pip} install -r {reqs}",
                                          timeout=1200.,
                                          notify=True,
                                          retries=3)
        except Exception:
            self._log_exc("Error updating python requirements")
        self._install_python_dist_requirements()

    def _install_python_dist_requirements(self):
        dist_reqs = self.python_dist_packages
        if dist_reqs is None:
            return
        dist_path = self.python_dist_path
        site_path = self.env_package_path
        for pkg in dist_reqs:
            for f in os.listdir(dist_path):
                if f.startswith(pkg):
                    src = os.path.join(dist_path, f)
                    dest = os.path.join(site_path, f)
                    self._notify_status(f"Linking to dist package: {pkg}")
                    if os.path.islink(dest):
                        os.remove(dest)
                    elif os.path.exists(dest):
                        self._notify_status(
                            f"Error symlinking dist package: {pkg}, "
                            f"file already exists: {dest}")
                        continue
                    os.symlink(src, dest)
                    break

    async def restart_service(self):
        self._notify_status("Restarting Service...")
        try:
            await self.cmd_helper.run_cmd(f"sudo systemctl restart {self.name}"
                                          )
        except Exception:
            raise self._log_exc("Error restarting service")

    def get_update_status(self):
        status = self.repo.get_repo_status()
        status['is_valid'] = self.is_valid
        status['debug_enabled'] = self.debug
        return status
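
The semantic-version check in update() compares the tuples built by
_get_version_info(), which extracts runs of digits with re.findall. A small
sketch of that comparison; note the tuples hold strings, so elements compare
lexicographically (e.g. '10' sorts before '9'), a quirk of the code as written:

import re

def to_tuple(v):
    # Mirrors _get_version_info(): digit runs only, kept as strings.
    return tuple(re.findall(r"\d+", v))

cur_version = to_tuple("v0.4.1-10")   # ('0', '4', '1', '10')
deps_version = to_tuple("v0.4.2")     # ('0', '4', '2')
print(cur_version < deps_version)     # True -> update_deps is set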
Code example #43
0
class WebSocketTest(WebSocketBaseTestCase):
    def get_app(self):
        self.close_future = Future()  # type: Future[None]
        return Application(
            [
                ("/echo", EchoHandler, dict(close_future=self.close_future)),
                ("/non_ws", NonWebSocketHandler),
                ("/header", HeaderHandler,
                 dict(close_future=self.close_future)),
                (
                    "/header_echo",
                    HeaderEchoHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/close_reason",
                    CloseReasonHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/error_in_on_message",
                    ErrorInOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/async_prepare",
                    AsyncPrepareHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/path_args/(.*)",
                    PathArgsHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/coroutine",
                    CoroutineOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                ("/render", RenderMessageHandler,
                 dict(close_future=self.close_future)),
                (
                    "/subprotocol",
                    SubprotocolHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/open_coroutine",
                    OpenCoroutineHandler,
                    dict(close_future=self.close_future, test=self),
                ),
                ("/error_in_open", ErrorInOpenHandler),
                ("/error_in_async_open", ErrorInAsyncOpenHandler),
                ("/nodelay", NoDelayHandler),
            ],
            template_loader=DictLoader(
                {"message.html": "<b>{{ message }}</b>"}),
        )

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        super(WebSocketTest, self).tearDown()
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch("/echo")
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "13",
            },
        )
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "12",
            },
        )
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self.ws_connect("/echo")
        yield ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    def test_websocket_callbacks(self):
        websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port(),
                          callback=self.stop)
        ws = self.wait().result()
        ws.write_message("hello")
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, "hello")
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(b"hello \xe9", binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(u"hello \u00e9")
        response = yield ws.read_message()
        self.assertEqual(response, u"hello \u00e9")

    @gen_test
    def test_render_message(self):
        ws = yield self.ws_connect("/render")
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self.ws_connect("/error_in_on_message")
        ws.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect("/notfound")
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_network_fail(self):
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect("ws://127.0.0.1:%d/" % port,
                                        connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        ws = yield websocket_connect("ws://127.0.0.1:%d/echo" %
                                     self.get_http_port())
        ws.write_message("hello")
        ws.write_message("world")
        # Close the underlying stream.
        ws.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header" % self.get_http_port(),
                headers={"X-Test": "hello"},
            ))
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
                headers={"X-Test-Hello": "hello"},
            ))
        self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
        self.assertEqual(ws.headers.get("X-Extra-Response-Header"),
                         "Extra-Response-Value")

    @gen_test
    def test_server_close_reason(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        ws = yield self.ws_connect("/echo")
        ws.close(1001, "goodbye")
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect("/async_prepare")
        ws.write_message("hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_path_args(self):
        ws = yield self.ws_connect("/path_args/hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect("/coroutine")
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        res = yield ws.read_message()
        self.assertEqual(res, "hello1")
        res = yield ws.read_message()
        self.assertEqual(res, "hello2")

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d" % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d/something" % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "127.0.0.1:%d" % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        headers = {"Origin": "http://somewhereelse.com"}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        url = "ws://localhost:%d/echo" % port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {"Origin": "http://subtenant.localhost"}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        ws = yield self.ws_connect("/subprotocol",
                                   subprotocols=["badproto", "goodproto"])
        self.assertEqual(ws.selected_subprotocol, "goodproto")
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        ws = yield self.ws_connect("/subprotocol")
        self.assertIs(ws.selected_subprotocol, None)
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        self.message_sent = Event()
        ws = yield self.ws_connect("/open_coroutine")
        yield ws.write_message("hello")
        self.message_sent.set()
        res = yield ws.read_message()
        self.assertEqual(res, "ok")

    @gen_test
    def test_error_in_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_error_in_async_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_async_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_nodelay(self):
        ws = yield self.ws_connect("/nodelay")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")
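
The handlers routed in get_app are defined elsewhere in the test module. As a
rough sketch only, an EchoHandler compatible with the echo tests above might
look like this (the close_future wiring is an assumption inferred from get_app):

from tornado.websocket import WebSocketHandler

class EchoHandler(WebSocketHandler):
    def initialize(self, close_future=None):
        self.close_future = close_future

    def on_message(self, message):
        # Echo each frame back, preserving binary vs. text framing.
        self.write_message(message, binary=isinstance(message, bytes))

    def on_connection_close(self):
        super().on_connection_close()
        # Resolve close_future with (code, reason) so tests can await it.
        if self.close_future is not None and not self.close_future.done():
            self.close_future.set_result((self.close_code, self.close_reason))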
Code example #44
0
class UpdateManager:
    def __init__(self, config):
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        self.repo_debug = config.getboolean('enable_repo_debug', False)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        if self.repo_debug:
            logging.warning("UPDATE MANAGER: REPO DEBUG ENABLED")
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self),
            "moonraker": GitUpdater(self, mooncfg, MOONRAKER_PATH, env)
        }
        self.current_update = None
        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(self, {
                'repo': client_repo,
                'path': client_path
            })
        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added" %
                                   (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(self, cfg)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(self, cfg)
            else:
                raise config.error("Invalid type '%s' for section [%s]" %
                                   (client_type, section))

        # GitHub API Rate Limit Tracking
        self.gh_rate_limit = None
        self.gh_limit_remaining = None
        self.gh_limit_reset_time = None
        self.gh_init_evt = Event()
        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        AsyncHTTPClient.configure(None, defaults=dict(user_agent="Moonraker"))
        self.http_client = AsyncHTTPClient()

        self.server.register_endpoint("/machine/update/moonraker", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/klipper", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/system", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/client", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initialize_updaters,
                                        list(self.updaters.values()))

    async def _initialize_updaters(self, initial_updaters):
        self.is_refreshing = True
        await self._init_api_rate_limit()
        for updater in initial_updaters:
            if isinstance(updater, PackageUpdater):
                ret = updater.refresh(False)
            else:
                ret = updater.refresh()
            if asyncio.iscoroutine(ret):
                await ret
        self.is_refreshing = False

    async def _set_klipper_repo(self):
        kinfo = self.server.get_klippy_info()
        if not kinfo:
            logging.info("No valid klippy info received")
            return
        kpath = kinfo['klipper_path']
        env = kinfo['python_path']
        kupdater = self.updaters.get('klipper', None)
        if kupdater is not None and kupdater.repo_path == kpath and \
                kupdater.env == env:
            # Current Klipper Updater is valid
            return
        kcfg = self.config[f"update_manager static {self.distro} klipper"]
        self.updaters['klipper'] = GitUpdater(self, kcfg, kpath, env)
        await self.updaters['klipper'].refresh()

    async def _check_klippy_printing(self):
        klippy_apis = self.server.lookup_plugin('klippy_apis')
        result = await klippy_apis.query_objects({'print_stats': None},
                                                 default={})
        pstate = result.get('print_stats', {}).get('state', "")
        return pstate.lower() == "printing"

    async def _handle_auto_refresh(self):
        if await self._check_klippy_printing():
            # Don't Refresh during a print
            logging.info("Klippy is printing, auto refresh aborted")
            return
        cur_time = time.time()
        cur_hour = time.localtime(cur_time).tm_hour
        time_diff = cur_time - self.last_auto_update_time
        # Update packages if it has been more than 12 hours
        # and the local time is between 12AM and 5AM
        if time_diff < MIN_REFRESH_TIME or cur_hour >= MAX_PKG_UPDATE_HOUR:
            # Not within the update time window
            return
        self.last_auto_update_time = cur_time
        vinfo = {}
        need_refresh_all = not self.is_refreshing
        async with self.cmd_request_lock:
            self.is_refreshing = True
            try:
                for name, updater in list(self.updaters.items()):
                    if need_refresh_all:
                        ret = updater.refresh()
                        if asyncio.iscoroutine(ret):
                            await ret
                    if hasattr(updater, "get_update_status"):
                        vinfo[name] = updater.get_update_status()
            except Exception:
                logging.exception("Unable to Refresh Status")
                return
            finally:
                self.is_refreshing = False
        uinfo = {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }
        self.server.send_event("update_manager:update_refreshed", uinfo)

    async def _handle_update_request(self, web_request):
        if await self._check_klippy_printing():
            raise self.server.error("Update Refused: Klippy is printing")
        app = web_request.get_endpoint().split("/")[-1]
        if app == "client":
            app = web_request.get('name')
        inc_deps = web_request.get_boolean('include_deps', False)
        if self.current_update is not None and \
                self.current_update[0] == app:
            return f"Object {app} is currently being updated"
        updater = self.updaters.get(app, None)
        if updater is None:
            raise self.server.error(f"Updater {app} not available")
        async with self.cmd_request_lock:
            self.current_update = (app, id(web_request))
            try:
                await updater.update(inc_deps)
            except Exception as e:
                self.notify_update_response(f"Error updating {app}")
                self.notify_update_response(str(e), is_complete=True)
                raise
            finally:
                self.current_update = None
        return "ok"

    async def _handle_status_request(self, web_request):
        check_refresh = web_request.get_boolean('refresh', False)
        # Don't refresh if a print is currently in progress or
        # if an update is in progress.  Just return the current
        # state
        if self.current_update is not None or \
                await self._check_klippy_printing():
            check_refresh = False
        need_refresh = False
        if check_refresh:
            # If there is an outstanding request processing a
            # refresh, we don't need to do it again.
            need_refresh = not self.is_refreshing
            await self.cmd_request_lock.acquire()
            self.is_refreshing = True
        vinfo = {}
        try:
            for name, updater in list(self.updaters.items()):
                await updater.check_initialized(120.)
                if need_refresh:
                    ret = updater.refresh()
                    if asyncio.iscoroutine(ret):
                        await ret
                if hasattr(updater, "get_update_status"):
                    vinfo[name] = updater.get_update_status()
        finally:
            if check_refresh:
                self.is_refreshing = False
                self.cmd_request_lock.release()
        return {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }

    async def execute_cmd(self, cmd, timeout=10., notify=False, retries=1):
        shell_command = self.server.lookup_plugin('shell_command')
        cb = self.notify_update_response if notify else None
        scmd = shell_command.build_shell_command(cmd, callback=cb)
        while retries:
            if await scmd.run(timeout=timeout, verbose=notify):
                break
            retries -= 1
        if not retries:
            raise self.server.error("Shell Command Error")

    async def execute_cmd_with_response(self, cmd, timeout=10.):
        shell_command = self.server.lookup_plugin('shell_command')
        scmd = shell_command.build_shell_command(cmd, None)
        result = await scmd.run_with_response(timeout, retries=5)
        if result is None:
            raise self.server.error(f"Error Running Command: {cmd}")
        return result

    async def _init_api_rate_limit(self):
        url = "https://api.github.com/rate_limit"
        while True:
            try:
                resp = await self.github_api_request(url, is_init=True)
                core = resp['resources']['core']
                self.gh_rate_limit = core['limit']
                self.gh_limit_remaining = core['remaining']
                self.gh_limit_reset_time = core['reset']
            except Exception:
                logging.exception("Error Initializing GitHub API Rate Limit")
                await tornado.gen.sleep(30.)
            else:
                reset_time = time.ctime(self.gh_limit_reset_time)
                logging.info(
                    "GitHub API Rate Limit Initialized\n"
                    f"Rate Limit: {self.gh_rate_limit}\n"
                    f"Rate Limit Remaining: {self.gh_limit_remaining}\n"
                    f"Rate Limit Reset Time: {reset_time}, "
                    f"Seconds Since Epoch: {self.gh_limit_reset_time}")
                break
        self.gh_init_evt.set()

    async def github_api_request(self, url, etag=None, is_init=False):
        if not is_init:
            timeout = IOLoop.current().time() + 30.
            try:
                await self.gh_init_evt.wait(timeout)
            except Exception:
                raise self.server.error("Timeout while waiting for GitHub "
                                        "API Rate Limit initialization")
        if self.gh_limit_remaining == 0:
            curtime = time.time()
            if curtime < self.gh_limit_reset_time:
                raise self.server.error(
                    f"GitHub Rate Limit Reached\nRequest: {url}\n"
                    f"Limit Reset Time: {time.ctime(self.gh_limit_remaining)}")
        headers = {"Accept": "application/vnd.github.v3+json"}
        if etag is not None:
            headers['If-None-Match'] = etag
        retries = 5
        while retries:
            try:
                timeout = IOLoop.current().time() + 10.
                fut = self.http_client.fetch(url,
                                             headers=headers,
                                             connect_timeout=5.,
                                             request_timeout=5.,
                                             raise_error=False)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                msg = f"Error Processing GitHub API request: {url}"
                if not retries:
                    raise self.server.error(msg)
                logging.exception(msg)
                await tornado.gen.sleep(1.)
                continue
            etag = resp.headers.get('etag', None)
            if etag is not None:
                if etag[:2] == "W/":
                    etag = etag[2:]
            logging.info("GitHub API Request Processed\n"
                         f"URL: {url}\n"
                         f"Response Code: {resp.code}\n"
                         f"Response Reason: {resp.reason}\n"
                         f"ETag: {etag}")
            if resp.code == 403:
                raise self.server.error(
                    f"Forbidden GitHub Request: {resp.reason}")
            elif resp.code == 304:
                logging.info(f"Github Request not Modified: {url}")
                return None
            if resp.code != 200:
                retries -= 1
                if not retries:
                    raise self.server.error(
                        f"Github Request failed: {resp.code} {resp.reason}")
                logging.info(
                    f"Github request error, {retries} retries remaining")
                await tornado.gen.sleep(1.)
                continue
            # Update rate limit on return success
            if 'X-Ratelimit-Limit' in resp.headers and not is_init:
                self.gh_rate_limit = int(resp.headers['X-Ratelimit-Limit'])
                self.gh_limit_remaining = int(
                    resp.headers['X-Ratelimit-Remaining'])
                self.gh_limit_reset_time = float(
                    resp.headers['X-Ratelimit-Reset'])
            decoded = json.loads(resp.body)
            decoded['etag'] = etag
            return decoded

    async def http_download_request(self, url):
        retries = 5
        while retries:
            try:
                timeout = IOLoop.current().time() + 130.
                fut = self.http_client.fetch(
                    url,
                    headers={"Accept": "application/zip"},
                    connect_timeout=5.,
                    request_timeout=120.)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                logging.exception("Error Processing Download")
                if not retries:
                    raise
                await tornado.gen.sleep(1.)
                continue
            return resp.body

    def notify_update_response(self, resp, is_complete=False):
        resp = resp.strip()
        if isinstance(resp, bytes):
            resp = resp.decode()
        notification = {
            'message': resp,
            'application': None,
            'proc_id': None,
            'complete': is_complete
        }
        if self.current_update is not None:
            notification['application'] = self.current_update[0]
            notification['proc_id'] = self.current_update[1]
        self.server.send_event("update_manager:update_response", notification)

    def close(self):
        self.http_client.close()
        if self.refresh_cb is not None:
            self.refresh_cb.stop()
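
A minimal sketch of how a caller can use the ETag support in github_api_request
above: hand back the etag from the previous response so GitHub can answer
304 Not Modified, which the method surfaces as None (the polling helper is
illustrative, not part of the class):

async def poll_latest_release(umgr, repo, etag=None):
    url = f"https://api.github.com/repos/{repo}/releases/latest"
    result = await umgr.github_api_request(url, etag=etag)
    if result is None:
        # 304 Not Modified: keep the old etag, nothing new to fetch.
        return etag, None
    return result.get('etag'), result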
Code example #45
0
File: core.py Project: indera/distributed
class ConnectionPool(object):
    """ A maximum sized pool of Tornado IOStreams

    This provides a connect method that mirrors the normal distributed.connect
    method, but provides connection sharing and tracks connection limits.

    This object provides an ``rpc`` like interface::

        >>> rpc = ConnectionPool(limit=512)
        >>> scheduler = rpc('127.0.0.1:8786')
        >>> workers = [rpc(ip=ip, port=port) for ip, port in ...]

        >>> info = yield scheduler.identity()

    It creates enough streams to satisfy concurrent connections to any
    particular address::

        >>> a, b = yield [scheduler.who_has(), scheduler.has_what()]

    It reuses existing streams so that we don't have to continuously reconnect.

    It also maintains a stream limit to avoid "too many open file handles"
    issues.  Whenever this maximum is reached we clear out all idling streams.
    If that doesn't do the trick then we wait until one of the occupied streams
    closes.
    """
    def __init__(self, limit=512):
        self.open = 0
        self.active = 0
        self.limit = limit
        self.available = defaultdict(set)
        self.occupied = defaultdict(set)
        self.event = Event()

    def __str__(self):
        return "<ConnectionPool: open=%d, active=%d>" % (self.open,
                self.active)

    __repr__ = __str__

    def __call__(self, arg=None, ip=None, port=None, addr=None):
        """ Cached rpc objects """
        ip, port = ip_port_from_args(arg=arg, addr=addr, ip=ip, port=port)
        return RPCCall(ip, port, self)

    @gen.coroutine
    def connect(self, ip, port, timeout=3):
        if self.available.get((ip, port)):
            stream = self.available[ip, port].pop()
            self.active += 1
            self.occupied[ip, port].add(stream)
            raise gen.Return(stream)

        while self.open >= self.limit:
            self.event.clear()
            self.collect()
            yield self.event.wait()

        self.open += 1
        stream = yield connect(ip=ip, port=port, timeout=timeout)
        stream.set_close_callback(lambda: self.on_close(ip, port, stream))
        self.active += 1
        self.occupied[ip, port].add(stream)

        if self.open >= self.limit:
            self.event.clear()

        raise gen.Return(stream)

    def on_close(self, ip, port, stream):
        self.open -= 1

        if stream in self.available[ip, port]:
            self.available[ip, port].remove(stream)
        if stream in self.occupied[ip, port]:
            self.occupied[ip, port].remove(stream)
            self.active -= 1

        if self.open <= self.limit:
            self.event.set()

    def collect(self):
        logger.info("Collecting unused streams.  open: %d, active: %d",
                    self.open, self.active)
        for streams in list(self.available.values()):
            for stream in streams:
                stream.close()

    def close(self):
        for streams in list(self.available.values()):
            for stream in streams:
                stream.close()
        for streams in list(self.occupied.values()):
            for stream in streams:
                stream.close()
Code example #46
0
class PackageUpdater:
    def __init__(self, umgr):
        self.server = umgr.server
        self.execute_cmd = umgr.execute_cmd
        self.execute_cmd_with_response = umgr.execute_cmd_with_response
        self.notify_update_response = umgr.notify_update_response
        self.available_packages = []
        self.init_evt = Event()
        self.refresh_condition = None

    async def refresh(self, fetch_packages=True):
        # TODO: Use python-apt python lib rather than command line for updates
        if self.refresh_condition is None:
            self.refresh_condition = Condition()
        else:
            await self.refresh_condition.wait()
            return
        try:
            if fetch_packages:
                await self.execute_cmd(f"{APT_CMD} update",
                                       timeout=300.,
                                       retries=3)
            res = await self.execute_cmd_with_response("apt list --upgradable",
                                                       timeout=60.)
            pkg_list = [p.strip() for p in res.split("\n") if p.strip()]
            if pkg_list:
                pkg_list = pkg_list[2:]  # drop apt's two header lines
                self.available_packages = [
                    p.split("/", maxsplit=1)[0] for p in pkg_list
                ]
            pkg_list = "\n".join(self.available_packages)
            logging.info(
                f"Detected {len(self.available_packages)} package updates:"
                f"\n{pkg_list}")
        except Exception:
            logging.exception("Error Refreshing System Packages")
        self.init_evt.set()
        self.refresh_condition.notify_all()
        self.refresh_condition = None

    async def check_initialized(self, timeout=None):
        if self.init_evt.is_set():
            return
        if timeout is not None:
            timeout = IOLoop.current().time() + timeout
        await self.init_evt.wait(timeout)

    async def update(self, *args):
        await self.check_initialized(20.)
        if self.refresh_condition is not None:
            await self.refresh_condition.wait()
        self.notify_update_response("Updating packages...")
        try:
            await self.execute_cmd(f"{APT_CMD} update",
                                   timeout=300.,
                                   notify=True)
            await self.execute_cmd(f"{APT_CMD} upgrade --yes",
                                   timeout=3600.,
                                   notify=True)
        except Exception:
            raise self.server.error("Error updating system packages")
        self.available_packages = []
        self.notify_update_response("Package update finished...",
                                    is_complete=True)

    def get_update_status(self):
        return {
            'package_count': len(self.available_packages),
            'package_list': self.available_packages
        }
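
A small sketch of the parsing done in refresh() above, run against canned
`apt list --upgradable` output (the sample lines are illustrative; the code
assumes the first two non-empty lines are headers):

res = (
    "WARNING: apt does not have a stable CLI interface.\n"
    "Listing...\n"
    "curl/stable 7.74.0-1.3 amd64 [upgradable from: 7.74.0-1.2]\n"
    "git/stable 1:2.30.2-1 amd64 [upgradable from: 1:2.30.1-1]\n"
)
pkg_list = [p.strip() for p in res.split("\n") if p.strip()]
pkg_list = pkg_list[2:]                       # drop the header lines
available = [p.split("/", maxsplit=1)[0] for p in pkg_list]
print(available)                              # ['curl', 'git']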
Code example #47
0
File: tornado_kazoo.py Project: AppScale/appscale
class AsyncKazooLock(object):
  """ A lock based on kazoo.recipe.Lock and modified to work as a coroutine.
  """

  # Node name, after the contender UUID, before the sequence
  # number. Involved in read/write locks.
  _NODE_NAME = "__lock__"

  # Node names which exclude this contender when present at a lower
  # sequence number. Involved in read/write locks.
  _EXCLUDE_NAMES = ["__lock__"]

  def __init__(self, client, path, identifier=None):
    """ Creates an AsyncKazooLock.

    Args:
      client: A KazooClient.
      path: The lock path to use.
      identifier: The name to use for this lock contender. This can be useful
        for querying to see who the current lock contenders are.
    """
    self.client = client
    self.tornado_kazoo = TornadoKazoo(client)
    self.path = path

    # some data is written to the node. this can be queried via
    # contenders() to see who is contending for the lock
    self.data = str(identifier or "").encode('utf-8')
    self.node = None

    self.wake_event = AsyncEvent()

    # props to Netflix Curator for this trick. It is possible for our
    # create request to succeed on the server, but for a failure to
    # prevent us from getting back the full path name. We prefix our
    # lock name with a uuid and can check for its presence on retry.
    self.prefix = uuid.uuid4().hex + self._NODE_NAME
    self.create_path = self.path + "/" + self.prefix

    self.create_tried = False
    self.is_acquired = False
    self.assured_path = False
    self.cancelled = False
    self._retry = AsyncKazooRetry(max_tries=-1)
    self._lock = AsyncLock()

  @gen.coroutine
  def _ensure_path(self):
    yield self.tornado_kazoo.ensure_path(self.path)
    self.assured_path = True

  def cancel(self):
    """ Cancels a pending lock acquire. """
    self.cancelled = True
    self.wake_event.set()

  @gen.coroutine
  def acquire(self, timeout=None, ephemeral=True):
    """ Acquires the lock. By default, it blocks and waits forever.

    Args:
      timeout: A float specifying how long to wait to acquire the lock.
      ephemeral: A boolean indicating that the lock should use an ephemeral
        node.

    Raises:
      LockTimeout if the lock wasn't acquired within `timeout` seconds.
    """
    retry = self._retry.copy()
    retry.deadline = timeout

    # Ensure we are locked so that we avoid multiple coroutines in
    # this acquisition routine at the same time...
    timeout_interval = None
    if timeout is not None:
      timeout_interval = datetime.timedelta(seconds=timeout)

    try:
      with (yield self._lock.acquire(timeout=timeout_interval)):
        already_acquired = self.is_acquired
        gotten = False
        try:
          gotten = yield retry(self._inner_acquire, timeout=timeout,
                               ephemeral=ephemeral)
        except RetryFailedError:
          pass
        except KazooException:
          # if we did ultimately fail, attempt to clean up
          exc_info = sys.exc_info()
          if not already_acquired:
            yield self._best_effort_cleanup()
            self.cancelled = False
          six.reraise(exc_info[0], exc_info[1], exc_info[2])
        if gotten:
          self.is_acquired = gotten
        if not gotten and not already_acquired:
          yield self._best_effort_cleanup()
        raise gen.Return(gotten)
    except gen.TimeoutError:
      raise LockTimeout("Failed to acquire lock on %s after "
                        "%s seconds" % (self.path, timeout))

  def _watch_session(self, state):
    self.wake_event.set()
    return True

  def _watch_session_listener(self, state):
    IOLoop.current().add_callback(self._watch_session, state)

  @gen.coroutine
  def _inner_acquire(self, timeout, ephemeral=True):

    # wait until it's our chance to get it..
    if self.is_acquired:
      raise ForceRetryError()

    # make sure our election parent node exists
    if not self.assured_path:
      yield self._ensure_path()

    node = None
    if self.create_tried:
      node = yield self._find_node()
    else:
      self.create_tried = True

    if not node:
      node = yield self.tornado_kazoo.create(
        self.create_path, self.data, ephemeral=ephemeral, sequence=True)
      # strip off path to node
      node = node[len(self.path) + 1:]

    self.node = node

    while True:
      self.wake_event.clear()

      # bail out with an exception if cancellation has been requested
      if self.cancelled:
        raise CancelledError()

      children = yield self._get_sorted_children()

      try:
        our_index = children.index(node)
      except ValueError:  # pragma: nocover
        # somehow we aren't in the children -- probably we are
        # recovering from a session failure and our ephemeral
        # node was removed
        raise ForceRetryError()

      predecessor = self.predecessor(children, our_index)
      if not predecessor:
        raise gen.Return(True)

      # otherwise we are in the mix. watch predecessor and bide our time
      predecessor = self.path + "/" + predecessor
      self.client.add_listener(self._watch_session_listener)
      try:
        yield self.tornado_kazoo.get(predecessor, self._watch_predecessor)
      except NoNodeError:
        pass  # predecessor has already been deleted
      else:
        try:
          yield self.wake_event.wait(timeout)
        except gen.TimeoutError:
          raise LockTimeout("Failed to acquire lock on %s after "
                            "%s seconds" % (self.path, timeout))
      finally:
        self.client.remove_listener(self._watch_session_listener)

  def predecessor(self, children, index):
    for c in reversed(children[:index]):
      if any(n in c for n in self._EXCLUDE_NAMES):
        return c
    return None

  def _watch_predecessor(self, event):
    self.wake_event.set()

  @gen.coroutine
  def _get_sorted_children(self):
    children = yield self.tornado_kazoo.get_children(self.path)

    # Node names are prefixed by a type: strip the prefix first, which may
    # be one of multiple values in case of a read-write lock, and return
    # only the sequence number (as a string since it is padded and will
    # sort correctly anyway).
    #
    # In some cases, the lock path may contain nodes with other prefixes
    # (eg. in case of a lease), just sort them last ('~' sorts after all
    # ASCII digits).
    def _seq(c):
      for name in ["__lock__", "__rlock__"]:
        idx = c.find(name)
        if idx != -1:
          return c[idx + len(name):]
      # Sort unknown node names eg. "lease_holder" last.
      return '~'

    children.sort(key=_seq)
    raise gen.Return(children)

  @gen.coroutine
  def _find_node(self):
    children = yield self.tornado_kazoo.get_children(self.path)
    for child in children:
      if child.startswith(self.prefix):
        raise gen.Return(child)
    raise gen.Return(None)

  @gen.coroutine
  def _delete_node(self, node):
    yield self.tornado_kazoo.delete(self.path + "/" + node)

  @gen.coroutine
  def _best_effort_cleanup(self):
    try:
      node = self.node
      if not node:
        node = yield self._find_node()
      if node:
        yield self._delete_node(node)
    except KazooException:  # pragma: nocover
      pass

  @gen.coroutine
  def release(self):
    """Release the lock immediately."""
    retry = self._retry.copy()
    release_response = yield retry(self._inner_release)
    raise gen.Return(release_response)

  @gen.coroutine
  def _inner_release(self):
    if not self.is_acquired:
      raise gen.Return(False)

    try:
      yield self._delete_node(self.node)
    except NoNodeError:  # pragma: nocover
      pass

    self.is_acquired = False
    self.node = None
    raise gen.Return(True)

  @gen.coroutine
  def contenders(self):
    """ Returns an ordered list of the current contenders for the lock. """
    # make sure our election parent node exists
    if not self.assured_path:
      yield self._ensure_path()

    children = yield self._get_sorted_children()

    contenders = []
    for child in children:
      try:
        result = yield self.tornado_kazoo.get(self.path + "/" + child)
        data = result[0]
        contenders.append(data.decode('utf-8'))
      except NoNodeError:  # pragma: nocover
        pass
    raise gen.Return(contenders)
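A hypothetical usage sketch for AsyncKazooLock: it assumes a ZooKeeper server at 127.0.0.1:2181 and that KazooClient and the lock class are importable; the critical section is a placeholder.

from kazoo.client import KazooClient
from tornado import gen
from tornado.ioloop import IOLoop


@gen.coroutine
def guarded_work():
    client = KazooClient(hosts="127.0.0.1:2181")
    client.start()
    lock = AsyncKazooLock(client, "/locks/my-resource", identifier="worker-1")
    try:
        acquired = yield lock.acquire(timeout=10)
        if acquired:
            pass  # critical section goes here
    finally:
        yield lock.release()  # returns False if the lock was never acquired
        client.stop()

IOLoop.current().run_sync(guarded_work)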
Code example #48
File: httpserver_test.py Project: bdarnell/tornado
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """

    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application(
            [
                ("/", HelloHandler),
                ("/large", LargeHandler),
                (
                    "/finish_on_close",
                    FinishOnCloseHandler,
                    dict(cleanup_event=self.cleanup_event),
                ),
            ]
        )

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"POST / HTTP/1.0\r\n"
            b"Connection: keep-alive\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"0\r\n"
            b"\r\n"
        )
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
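Outside the test harness, the same "crude manual client" technique looks like the sketch below. It assumes a plain HTTP/1.1 server on 127.0.0.1:8888 (a hypothetical address) that answers with a Content-Length header:

import socket
from tornado import gen
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream


@gen.coroutine
def two_requests_one_connection():
    stream = IOStream(socket.socket())
    yield stream.connect(("127.0.0.1", 8888))
    try:
        for _ in range(2):  # two requests over a single connection
            yield stream.write(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")
            first_line = yield stream.read_until(b"\r\n")
            assert first_line.startswith(b"HTTP/1.1 200"), first_line
            header_bytes = yield stream.read_until(b"\r\n\r\n")
            headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
            body = yield stream.read_bytes(int(headers["Content-Length"]))
            print(body)
    finally:
        stream.close()

IOLoop.current().run_sync(two_requests_one_connection)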
Code example #49
File: queues.py Project: zhiyajun11/tornado
class Queue(object):
    """Coordinate producer and consumer coroutines.

    If maxsize is 0 (the default) the queue size is unbounded.
    """
    def __init__(self, maxsize=0):
        if maxsize is None:
            raise TypeError("maxsize can't be None")

        if maxsize < 0:
            raise ValueError("maxsize can't be negative")

        self._maxsize = maxsize
        self._init()
        self._getters = collections.deque([])  # Futures.
        self._putters = collections.deque([])  # Pairs of (item, Future).
        self._unfinished_tasks = 0
        self._finished = Event()
        self._finished.set()

    @property
    def maxsize(self):
        """Number of items allowed in the queue."""
        return self._maxsize

    def qsize(self):
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self):
        return not self._queue

    def full(self):
        if self.maxsize == 0:
            return False
        else:
            return self.qsize() >= self.maxsize

    def put(self, item, timeout=None):
        """Put an item into the queue, perhaps waiting until there is room.

        Returns a Future, which raises `tornado.gen.TimeoutError` after a
        timeout.
        """
        try:
            self.put_nowait(item)
        except QueueFull:
            future = Future()
            self._putters.append((item, future))
            _set_timeout(future, timeout)
            return future
        else:
            return gen._null_future

    def put_nowait(self, item):
        """Put an item into the queue without blocking.

        If no free slot is immediately available, raise `QueueFull`.
        """
        self._consume_expired()
        if self._getters:
            assert self.empty(), "queue non-empty, why are getters waiting?"
            getter = self._getters.popleft()
            self._put(item)
            getter.set_result(self._get())
        elif self.full():
            raise QueueFull
        else:
            self._put(item)

    def get(self, timeout=None):
        """Remove and return an item from the queue.

        Returns a Future which resolves once an item is available, or raises
        `tornado.gen.TimeoutError` after a timeout.
        """
        future = Future()
        try:
            future.set_result(self.get_nowait())
        except QueueEmpty:
            self._getters.append(future)
            _set_timeout(future, timeout)
        return future

    def get_nowait(self):
        """Remove and return an item from the queue without blocking.

        Return an item if one is immediately available, else raise
        `QueueEmpty`.
        """
        self._consume_expired()
        if self._putters:
            assert self.full(), "queue not full, why are putters waiting?"
            item, putter = self._putters.popleft()
            self._put(item)
            putter.set_result(None)
            return self._get()
        elif self.qsize():
            return self._get()
        else:
            raise QueueEmpty

    def task_done(self):
        """Indicate that a formerly enqueued task is complete.

        Used by queue consumers. For each `.get` used to fetch a task, a
        subsequent call to `.task_done` tells the queue that the processing
        on the task is complete.

        If a `.join` is blocking, it resumes when all items have been
        processed; that is, when every `.put` is matched by a `.task_done`.

        Raises `ValueError` if called more times than `.put`.
        """
        if self._unfinished_tasks <= 0:
            raise ValueError('task_done() called too many times')
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    def join(self, timeout=None):
        """Block until all items in the queue are processed. Returns a Future.

        Returns a Future, which raises `tornado.gen.TimeoutError` after a
        timeout.
        """
        return self._finished.wait(timeout)

    def _init(self):
        self._queue = collections.deque()

    def _get(self):
        return self._queue.popleft()

    def _put(self, item):
        self._unfinished_tasks += 1
        self._finished.clear()
        self._queue.append(item)

    def _consume_expired(self):
        # Remove timed-out waiters.
        while self._putters and self._putters[0][1].done():
            self._putters.popleft()

        while self._getters and self._getters[0].done():
            self._getters.popleft()

    def __repr__(self):
        return '<%s at %s %s>' % (
            type(self).__name__, hex(id(self)), self._format())

    def __str__(self):
        return '<%s %s>' % (type(self).__name__, self._format())

    def _format(self):
        result = 'maxsize=%r' % (self.maxsize, )
        if getattr(self, '_queue', None):
            result += ' queue=%r' % self._queue
        if self._getters:
            result += ' getters[%s]' % len(self._getters)
        if self._putters:
            result += ' putters[%s]' % len(self._putters)
        if self._unfinished_tasks:
            result += ' tasks=%s' % self._unfinished_tasks
        return result
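A typical producer/consumer pairing for this Queue (the same usage works with tornado.queues.Queue); a sketch:

from tornado import gen
from tornado.ioloop import IOLoop


@gen.coroutine
def main():
    q = Queue(maxsize=2)

    @gen.coroutine
    def producer():
        for item in range(5):
            yield q.put(item)  # waits while the queue is full

    @gen.coroutine
    def consumer():
        while True:
            item = yield q.get()
            try:
                print("processing", item)
            finally:
                q.task_done()

    IOLoop.current().spawn_callback(consumer)
    yield producer()
    yield q.join()  # resolves once every item is marked task_done()

IOLoop.current().run_sync(main)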
Code example #50
File: core.py Project: mariusvniekerk/distributed
class ConnectionPool(object):
    """ A maximum sized pool of Comm objects.

    This provides a connect method that mirrors the normal distributed.connect
    method, but provides connection sharing and tracks connection limits.

    This object provides an ``rpc`` like interface::

        >>> rpc = ConnectionPool(limit=512)
        >>> scheduler = rpc('127.0.0.1:8786')
        >>> workers = [rpc(address) for address ...]

        >>> info = yield scheduler.identity()

    It creates enough comms to satisfy concurrent connections to any
    particular address::

        >>> a, b = yield [scheduler.who_has(), scheduler.has_what()]

    It reuses existing comms so that we don't have to continuously reconnect.

    It also maintains a comm limit to avoid "too many open file handles"
    issues.  Whenever this maximum is reached we clear out all idling comms.
    If that doesn't do the trick then we wait until one of the occupied comms
    closes.

    Parameters
    ----------
    limit: int
        The number of open comms to maintain at once
    deserialize: bool
        Whether or not to deserialize data by default or pass it through
    """
    def __init__(self, limit=512, deserialize=True):
        self.open = 0  # Total number of open comms
        self.active = 0  # Number of comms currently in use
        self.limit = limit  # Max number of open comms
        # Invariant: len(available) == open - active
        self.available = defaultdict(set)
        # Invariant: len(occupied) == active
        self.occupied = defaultdict(set)
        self.deserialize = deserialize
        self.event = Event()

    def __str__(self):
        return "<ConnectionPool: open=%d, active=%d>" % (self.open,
                                                         self.active)

    __repr__ = __str__

    def __call__(self, addr=None, ip=None, port=None):
        """ Cached rpc objects """
        addr = addr_from_args(addr=addr, ip=ip, port=port)
        return PooledRPCCall(addr, self)

    @gen.coroutine
    def connect(self, addr, timeout=3):
        """
        Get a Comm to the given address.  For internal use.
        """
        available = self.available[addr]
        occupied = self.occupied[addr]
        if available:
            comm = available.pop()
            if not comm.closed():
                self.active += 1
                occupied.add(comm)
                raise gen.Return(comm)
            else:
                self.open -= 1

        while self.open >= self.limit:
            self.event.clear()
            self.collect()
            yield self.event.wait()

        self.open += 1
        try:
            comm = yield connect(addr,
                                 timeout=timeout,
                                 deserialize=self.deserialize)
        except Exception:
            self.open -= 1
            raise
        self.active += 1
        occupied.add(comm)

        if self.open >= self.limit:
            self.event.clear()

        raise gen.Return(comm)

    def reuse(self, addr, comm):
        """
        Reuse an open communication to the given address.  For internal use.
        """
        self.occupied[addr].remove(comm)
        self.active -= 1
        if comm.closed():
            self.open -= 1
            if self.open < self.limit:
                self.event.set()
        else:
            self.available[addr].add(comm)

    def collect(self):
        """
        Collect open but unused communications, to allow opening other ones.
        """
        logger.info("Collecting unused comms.  open: %d, active: %d",
                    self.open, self.active)
        for addr, comms in self.available.items():
            for comm in comms:
                comm.close()
            comms.clear()
        self.open = self.active
        if self.open < self.limit:
            self.event.set()

    def close(self):
        """
        Close all communications abruptly.
        """
        for comms in self.available.values():
            for comm in comms:
                comm.abort()
        for comms in self.occupied.values():
            for comm in comms:
                comm.abort()
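A sketch of the connect/reuse lifecycle the pool manages. It assumes a distributed Server (such as a scheduler) listening at the given address and the message format shown in the Server example further below:

from tornado import gen


@gen.coroutine
def identity_via_pool():
    pool = ConnectionPool(limit=512)
    addr = "tcp://127.0.0.1:8786"
    comm = yield pool.connect(addr)
    try:
        yield comm.write({"op": "identity", "reply": True})
        reply = yield comm.read()
    finally:
        pool.reuse(addr, comm)  # hand the comm back for other callers
    raise gen.Return(reply)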
Code example #51
File: queues.py Project: fuglede/distributed
class Queue:
    """ Distributed Queue

    This allows multiple clients to share futures or small bits of data between
    each other with a multi-producer/multi-consumer queue.  All metadata is
    sequentialized through the scheduler.

    Elements of the Queue must be either Futures or msgpack-encodable data
    (ints, strings, lists, dicts).  All data is sent through the scheduler so
    it is wise not to send large objects.  To share large objects scatter the
    data and share the future instead.

    .. warning::

       This object is experimental and has known issues in Python 2

    Examples
    --------
    >>> from dask.distributed import Client, Queue  # doctest: +SKIP
    >>> client = Client()  # doctest: +SKIP
    >>> queue = Queue('x')  # doctest: +SKIP
    >>> future = client.submit(f, x)  # doctest: +SKIP
    >>> queue.put(future)  # doctest: +SKIP

    See Also
    --------
    Variable: shared variable between clients
    """

    def __init__(self, name=None, client=None, maxsize=0):
        self.client = client or _get_global_client()
        self.name = name or "queue-" + uuid.uuid4().hex
        self._event_started = Event()
        if self.client.asynchronous or getattr(
            thread_state, "on_event_loop_thread", False
        ):

            async def _create_queue():
                await self.client.scheduler.queue_create(
                    name=self.name, maxsize=maxsize
                )
                self._event_started.set()

            self.client.loop.add_callback(_create_queue)
        else:
            sync(
                self.client.loop,
                self.client.scheduler.queue_create,
                name=self.name,
                maxsize=maxsize,
            )
            self._event_started.set()

    def __await__(self):
        async def _():
            await self._event_started.wait()
            return self

        return _().__await__()

    async def _put(self, value, timeout=None):
        if isinstance(value, Future):
            await self.client.scheduler.queue_put(
                key=tokey(value.key), timeout=timeout, name=self.name
            )
        else:
            await self.client.scheduler.queue_put(
                data=value, timeout=timeout, name=self.name
            )

    def put(self, value, timeout=None, **kwargs):
        """ Put data into the queue """
        return self.client.sync(self._put, value, timeout=timeout, **kwargs)

    def get(self, timeout=None, batch=False, **kwargs):
        """ Get data from the queue

        Parameters
        ----------
        timeout: Number (optional)
            Time in seconds to wait before timing out
        batch: boolean, int (optional)
            If True then return all elements currently waiting in the queue.
            If an integer then return that many elements from the queue.
            If False (default) then return one item at a time.
        """
        return self.client.sync(self._get, timeout=timeout, batch=batch, **kwargs)

    def qsize(self, **kwargs):
        """ Current number of elements in the queue """
        return self.client.sync(self._qsize, **kwargs)

    async def _get(self, timeout=None, batch=False):
        try:
            resp = await self.client.scheduler.queue_get(
                timeout=timeout, name=self.name, batch=batch
            )
        except gen.TimeoutError:
            raise TimeoutError("Timed out waiting for Queue")

        def process(d):
            if d["type"] == "Future":
                value = Future(d["value"], self.client, inform=True, state=d["state"])
                if d["state"] == "erred":
                    value._state.set_error(d["exception"], d["traceback"])
                self.client._send_to_scheduler(
                    {"op": "queue-future-release", "name": self.name, "key": d["value"]}
                )
            else:
                value = d["value"]

            return value

        if batch is False:
            result = process(resp)
        else:
            result = list(map(process, resp))

        return result

    async def _qsize(self):
        result = await self.client.scheduler.queue_qsize(name=self.name)
        return result

    def close(self):
        if self.client.status == "running":  # TODO: can leave zombie futures
            self.client._send_to_scheduler({"op": "queue_release", "name": self.name})

    def __getstate__(self):
        return (self.name, self.client.scheduler.address)

    def __setstate__(self, state):
        name, address = state
        try:
            client = get_client(address)
            assert client.scheduler.address == address
        except (AttributeError, AssertionError):
            client = Client(address, set_as_default=False)
        self.__init__(name=name, client=client)
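Because __getstate__/__setstate__ reduce the Queue to (name, scheduler address), it can be shipped into tasks and reconnected on the worker. A sketch assuming a running local cluster:

from dask.distributed import Client, Queue

client = Client()  # starts a local cluster
q = Queue("work-items", client=client)


def producer(q):
    # On the worker, q is rebuilt via __setstate__ from its name/address.
    q.put(42)
    return True


client.submit(producer, q).result()
print(q.get())  # -> 42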
Code example #52
File: nanny.py Project: jochen-ott-by/distributed
class WorkerProcess(object):
    def __init__(
        self,
        worker_kwargs,
        worker_start_args,
        silence_logs,
        on_exit,
        worker,
        env,
        config,
    ):
        self.status = "init"
        self.silence_logs = silence_logs
        self.worker_kwargs = worker_kwargs
        self.worker_start_args = worker_start_args
        self.on_exit = on_exit
        self.process = None
        self.Worker = worker
        self.env = env
        self.config = config

        # Initialized when worker is ready
        self.worker_dir = None
        self.worker_address = None

    async def start(self):
        """
        Ensure the worker process is started.
        """
        enable_proctitle_on_children()
        if self.status == "running":
            return self.status
        if self.status == "starting":
            await self.running.wait()
            return self.status

        self.init_result_q = init_q = mp_context.Queue()
        self.child_stop_q = mp_context.Queue()
        uid = uuid.uuid4().hex

        self.process = AsyncProcess(
            target=self._run,
            name="Dask Worker process (from Nanny)",
            kwargs=dict(
                worker_kwargs=self.worker_kwargs,
                worker_start_args=self.worker_start_args,
                silence_logs=self.silence_logs,
                init_result_q=self.init_result_q,
                child_stop_q=self.child_stop_q,
                uid=uid,
                Worker=self.Worker,
                env=self.env,
                config=self.config,
            ),
        )
        self.process.daemon = dask.config.get("distributed.worker.daemon", default=True)
        self.process.set_exit_callback(self._on_exit)
        self.running = Event()
        self.stopped = Event()
        self.status = "starting"
        try:
            await self.process.start()
        except OSError:
            logger.exception("Nanny failed to start process")
            self.process.terminate()
            return

        msg = await self._wait_until_connected(uid)
        if not msg:
            return self.status
        self.worker_address = msg["address"]
        self.worker_dir = msg["dir"]
        assert self.worker_address
        self.status = "running"
        self.running.set()

        init_q.close()

        return self.status

    def _on_exit(self, proc):
        if proc is not self.process:
            # Ignore exit of old process instance
            return
        self.mark_stopped()

    def _death_message(self, pid, exitcode):
        assert exitcode is not None
        if exitcode == 255:
            return "Worker process %d was killed by unknown signal" % (pid,)
        elif exitcode >= 0:
            return "Worker process %d exited with status %d" % (pid, exitcode)
        else:
            return "Worker process %d was killed by signal %d" % (pid, -exitcode)

    def is_alive(self):
        return self.process is not None and self.process.is_alive()

    @property
    def pid(self):
        return self.process.pid if self.process and self.process.is_alive() else None

    def mark_stopped(self):
        if self.status != "stopped":
            r = self.process.exitcode
            assert r is not None
            if r != 0:
                msg = self._death_message(self.process.pid, r)
                logger.info(msg)
            self.status = "stopped"
            self.stopped.set()
            # Release resources
            self.process.close()
            self.init_result_q = None
            self.child_stop_q = None
            self.process = None
            # Best effort to clean up worker directory
            if self.worker_dir and os.path.exists(self.worker_dir):
                shutil.rmtree(self.worker_dir, ignore_errors=True)
            self.worker_dir = None
            # User hook
            if self.on_exit is not None:
                self.on_exit(r)

    async def kill(self, timeout=2, executor_wait=True):
        """
        Ensure the worker process is stopped, waiting at most
        *timeout* seconds before terminating it abruptly.
        """
        loop = IOLoop.current()
        deadline = loop.time() + timeout

        if self.status == "stopped":
            return
        if self.status == "stopping":
            await self.stopped.wait()
            return
        assert self.status in ("starting", "running")
        self.status = "stopping"

        process = self.process
        self.child_stop_q.put(
            {
                "op": "stop",
                "timeout": max(0, deadline - loop.time()) * 0.8,
                "executor_wait": executor_wait,
            }
        )
        self.child_stop_q.close()

        while process.is_alive() and loop.time() < deadline:
            await asyncio.sleep(0.05)

        if process.is_alive():
            logger.warning(
                "Worker process still alive after %d seconds, killing", timeout
            )
            try:
                await process.terminate()
            except Exception as e:
                logger.error("Failed to kill worker process: %s", e)

    async def _wait_until_connected(self, uid):
        delay = 0.05
        while True:
            if self.status != "starting":
                return
            try:
                msg = self.init_result_q.get_nowait()
            except Empty:
                await asyncio.sleep(delay)
                continue

            if msg["uid"] != uid:  # ensure that we didn't cross queues
                continue

            if "exception" in msg:
                logger.error(
                    "Failed while trying to start worker process: %s", msg["exception"]
                )
                await self.process.join()
                raise msg["exception"]
            else:
                return msg

    @classmethod
    def _run(
        cls,
        worker_kwargs,
        worker_start_args,
        silence_logs,
        init_result_q,
        child_stop_q,
        uid,
        env,
        config,
        Worker,
    ):  # pragma: no cover
        os.environ.update(env)
        dask.config.set(config)
        try:
            from dask.multiprocessing import initialize_worker_process
        except ImportError:  # old Dask version
            pass
        else:
            initialize_worker_process()

        if silence_logs:
            logger.setLevel(silence_logs)

        IOLoop.clear_instance()
        loop = IOLoop()
        loop.make_current()
        worker = Worker(**worker_kwargs)

        async def do_stop(timeout=5, executor_wait=True):
            try:
                await worker.close(
                    report=False,
                    nanny=False,
                    executor_wait=executor_wait,
                    timeout=timeout,
                )
            finally:
                loop.stop()

        def watch_stop_q():
            """
            Wait for an incoming stop message and then stop the
            worker cleanly.
            """
            while True:
                try:
                    msg = child_stop_q.get(timeout=1000)
                except Empty:
                    pass
                else:
                    child_stop_q.close()
                    assert msg.pop("op") == "stop"
                    loop.add_callback(do_stop, **msg)
                    break

        t = threading.Thread(target=watch_stop_q, name="Nanny stop queue watch")
        t.daemon = True
        t.start()

        async def run():
            """
            Try to start worker and inform parent of outcome.
            """
            try:
                await worker
            except Exception as e:
                logger.exception("Failed to start worker")
                init_result_q.put({"uid": uid, "exception": e})
                init_result_q.close()
            else:
                try:
                    assert worker.address
                except ValueError:
                    pass
                else:
                    init_result_q.put(
                        {
                            "address": worker.address,
                            "dir": worker.local_directory,
                            "uid": uid,
                        }
                    )
                    init_result_q.close()
                    await worker.finished()
                    logger.info("Worker closed")

        try:
            loop.run_sync(run)
        except TimeoutError:
            # Loop was stopped before wait_until_closed() returned, ignore
            pass
        except KeyboardInterrupt:
            pass
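The heart of start()/_wait_until_connected() is a uid-tagged handshake over a multiprocessing queue: the child posts either its address or the startup exception, and the parent ignores messages whose uid does not match (stale messages from a previous incarnation). A stripped-down sketch:

import multiprocessing as mp
import uuid


def child(init_q, uid):
    try:
        address = "tcp://127.0.0.1:9999"  # stand-in for real worker startup
    except Exception as e:
        init_q.put({"uid": uid, "exception": e})
    else:
        init_q.put({"uid": uid, "address": address})


if __name__ == "__main__":
    uid = uuid.uuid4().hex
    init_q = mp.Queue()
    proc = mp.Process(target=child, args=(init_q, uid))
    proc.start()
    msg = init_q.get()
    if msg["uid"] == uid:  # ensure we didn't cross queues
        if "exception" in msg:
            raise msg["exception"]
        print("worker at", msg["address"])
    proc.join()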
Code example #53
File: async_task_manager.py Project: rydzykje/aucote
class AsyncTaskManager(object):
    """
    Aucote uses asynchronous tasks executed in the ioloop. Some of them,
    especially scanners, should finish before the ioloop stops.

    This class should be accessed via the `instance` class method, which
    returns the global instance of the task manager.

    """
    _instances = {}

    TASKS_POLITIC_WAIT = 0
    TASKS_POLITIC_KILL_WORKING_FIRST = 1
    TASKS_POLITIC_KILL_PROPORTIONS = 2
    TASKS_POLITIC_KILL_WORKING = 3

    def __init__(self, parallel_tasks=10):
        self._shutdown_condition = Event()
        self._stop_condition = Event()
        self._cron_tasks = {}
        self._parallel_tasks = parallel_tasks
        self._tasks = Queue()
        self._task_workers = {}
        self._events = {}
        self._limit = self._parallel_tasks
        self._next_task_number = 0
        self._toucan_keys = {}

    @classmethod
    def instance(cls, name=None, **kwargs):
        """
        Return instance of AsyncTaskManager

        Returns:
            AsyncTaskManager

        """
        if cls._instances.get(name) is None:
            cls._instances[name] = AsyncTaskManager(**kwargs)
        return cls._instances[name]

    @property
    def shutdown_condition(self):
        """
        Event which is resolved if every job is done and AsyncTaskManager is ready to shutdown

        Returns:
            Event
        """
        return self._shutdown_condition

    def start(self):
        """
        Start CronTabCallback tasks

        Returns:
            None

        """
        for task in self._cron_tasks.values():
            task.start()

        for number in range(self._parallel_tasks):
            self._task_workers[number] = IOLoop.current().add_callback(
                partial(self.process_tasks, number))

        self._next_task_number = self._parallel_tasks

    def add_crontab_task(self, task, cron, event=None):
        """
        Add function to scheduler and execute at cron time

        Args:
            task (function):
            cron (str): crontab value
            event (Event): event which prevents running tasks with a similar aim, e.g. security scans

        Returns:
            None

        """
        if event is not None:
            event = self._events.setdefault(event, Event())
        self._cron_tasks[task] = AsyncCrontabTask(cron, task, event)

    @gen.coroutine
    def stop(self):
        """
        Stop CronTabCallback tasks and wait for them to finish

        Returns:
            None

        """
        for task in self._cron_tasks.values():
            task.stop()
        IOLoop.current().add_callback(self._prepare_shutdown)
        yield [self._stop_condition.wait(), self._tasks.join()]
        self._shutdown_condition.set()

    def _prepare_shutdown(self):
        """
        Check if ioloop can be stopped

        Returns:
            None

        """
        if any(task.is_running() for task in self._cron_tasks.values()):
            IOLoop.current().add_callback(self._prepare_shutdown)
            return

        self._stop_condition.set()

    def clear(self):
        """
        Clear list of tasks

        Returns:
            None

        """
        self._cron_tasks = {}
        self._shutdown_condition.clear()
        self._stop_condition.clear()

    async def process_tasks(self, number):
        """
        Consume the task queue. Every task is executed in a separate thread (_Executor)

        """
        log.info("Starting worker %s", number)
        while True:
            try:
                item = self._tasks.get_nowait()
                try:
                    log.debug("Worker %s: starting %s", number, item)
                    thread = _Executor(task=item, number=number)
                    self._task_workers[number] = thread
                    thread.start()

                    while thread.is_alive():
                        await sleep(0.5)
                except:
                    log.exception("Worker %s: exception occurred", number)
                finally:
                    log.debug("Worker %s: %s finished", number, item)
                    self._tasks.task_done()
                    tasks_per_scan = (
                        '{}: {}'.format(scanner, len(tasks))
                        for scanner, tasks in self.tasks_by_scan.items())
                    log.debug("Tasks left in queue: %s (%s)",
                              self.unfinished_tasks, ', '.join(tasks_per_scan))
                    self._task_workers[number] = None
            except QueueEmpty:
                await gen.sleep(0.5)
                if self._stop_condition.is_set() and self._tasks.empty():
                    return
            finally:
                if self._limit < len(self._task_workers):
                    break

        del self._task_workers[number]

        log.info("Closing worker %s", number)

    def add_task(self, task):
        """
        Add task to the queue

        Args:
            task:

        Returns:
            None

        """
        self._tasks.put(task)

    @property
    def unfinished_tasks(self):
        """
        Tasks which are still being processed or waiting in the queue

        Returns:
            int

        """
        return self._tasks._unfinished_tasks

    @property
    def tasks_by_scan(self):
        """
        Returns queued tasks grouped by scan
        """
        tasks = self._tasks._queue

        return_value = {}

        for task in tasks:
            return_value.setdefault(task.context.scanner.NAME, []).append(task)

        return return_value

    @property
    def cron_tasks(self):
        """
        List of cron tasks

        Returns:
            list

        """
        return self._cron_tasks.values()

    def cron_task(self, name):
        for task in self._cron_tasks.values():
            if task.func.NAME == name:
                return task

    def change_throttling_toucan(self, key, value):
        self.change_throttling(value)

    def change_throttling(self, new_value):
        """
        Change throttling value. Keeps throttling value between 0 and 1.

        Behaviour of algorithm is described in docs/throttling.md

        Only working tasks are stopped here. Idle workers stop by themselves

        """
        if new_value > 1:
            new_value = 1
        if new_value < 0:
            new_value = 0

        new_value = round(new_value * 100) / 100

        old_limit = self._limit
        self._limit = round(self._parallel_tasks * float(new_value))

        working_tasks = [
            number for number, task in self._task_workers.items()
            if task is not None
        ]
        current_tasks = len(self._task_workers)

        task_politic = cfg['service.scans.task_politic']

        if task_politic == self.TASKS_POLITIC_KILL_WORKING_FIRST:
            tasks_to_kill = current_tasks - self._limit
        elif task_politic == self.TASKS_POLITIC_KILL_PROPORTIONS:
            tasks_to_kill = round((old_limit - self._limit) *
                                  len(working_tasks) / self._parallel_tasks)
        elif task_politic == self.TASKS_POLITIC_KILL_WORKING:
            tasks_to_kill = (old_limit - self._limit) - (
                len(self._task_workers) - len(working_tasks))
        else:
            tasks_to_kill = 0

        log.debug('%s tasks will be killed', tasks_to_kill)

        for number in working_tasks:
            if tasks_to_kill <= 0:
                break
            self._task_workers[number].stop()
            tasks_to_kill -= 1

        self._limit = round(self._parallel_tasks * float(new_value))

        current_tasks = len(self._task_workers)

        for number in range(self._limit - current_tasks):
            self._task_workers[self._next_task_number] = None
            IOLoop.current().add_callback(
                partial(self.process_tasks, self._next_task_number))
            self._next_task_number += 1
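A hypothetical driving sketch for AsyncTaskManager (the cron expression format is whatever AsyncCrontabTask accepts; the scan coroutine is a placeholder):

from tornado.ioloop import IOLoop

manager = AsyncTaskManager.instance(parallel_tasks=4)


async def nightly_scan():
    pass  # placeholder for a real scanner coroutine


manager.add_crontab_task(nightly_scan, "0 2 * * *")
manager.start()
IOLoop.current().start()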
Code example #54
class Server(object):
    """ Dask Distributed Server

    Superclass for endpoints in a distributed cluster, such as Worker
    and Scheduler objects.

    **Handlers**

    Servers define operations with a ``handlers`` dict mapping operation names
    to functions.  The first argument of a handler function will be a ``Comm``
    for the communication established with the client.  Other arguments
    will receive inputs from the keys of the incoming message which will
    always be a dictionary.

    >>> def pingpong(comm):
    ...     return b'pong'

    >>> def add(comm, x, y):
    ...     return x + y

    >>> handlers = {'ping': pingpong, 'add': add}
    >>> server = Server(handlers)  # doctest: +SKIP
    >>> server.listen('tcp://0.0.0.0:8000')  # doctest: +SKIP

    **Message Format**

    The server expects messages to be dictionaries with a special key, `'op'`,
    that corresponds to the name of the operation, and other key-value pairs as
    required by the function.

    So in the example above the following would be good messages.

    *  ``{'op': 'ping'}``
    *  ``{'op': 'add', 'x': 10, 'y': 20}``

    """

    default_ip = ""
    default_port = 0

    def __init__(
        self,
        handlers,
        blocked_handlers=None,
        stream_handlers=None,
        connection_limit=512,
        deserialize=True,
        io_loop=None,
    ):
        self.handlers = {
            "identity": self.identity,
            "connection_stream": self.handle_stream,
        }
        self.handlers.update(handlers)
        if blocked_handlers is None:
            blocked_handlers = dask.config.get(
                "distributed.%s.blocked-handlers" %
                type(self).__name__.lower(), [])
        self.blocked_handlers = blocked_handlers
        self.stream_handlers = {}
        self.stream_handlers.update(stream_handlers or {})

        self.id = type(self).__name__ + "-" + str(uuid.uuid4())
        self._address = None
        self._listen_address = None
        self._port = None
        self._comms = {}
        self.deserialize = deserialize
        self.monitor = SystemMonitor()
        self.counters = None
        self.digests = None
        self.events = None
        self.event_counts = None
        self._ongoing_coroutines = weakref.WeakSet()
        self._event_finished = Event()

        self.listener = None
        self.io_loop = io_loop or IOLoop.current()
        self.loop = self.io_loop

        if not hasattr(self.io_loop, "profile"):
            ref = weakref.ref(self.io_loop)

            if hasattr(self.io_loop, "asyncio_loop"):

                def stop():
                    loop = ref()
                    return loop is None or loop.asyncio_loop.is_closed()

            else:

                def stop():
                    loop = ref()
                    return loop is None or loop._closing

            self.io_loop.profile = profile.watch(
                omit=("profile.py", "selectors.py"),
                interval=dask.config.get(
                    "distributed.worker.profile.interval"),
                cycle=dask.config.get("distributed.worker.profile.cycle"),
                stop=stop,
            )

        # Statistics counters for various events
        with ignoring(ImportError):
            from .counter import Digest

            self.digests = defaultdict(partial(Digest, loop=self.io_loop))

        from .counter import Counter

        self.counters = defaultdict(partial(Counter, loop=self.io_loop))
        self.events = defaultdict(lambda: deque(maxlen=10000))
        self.event_counts = defaultdict(lambda: 0)

        self.periodic_callbacks = dict()

        pc = PeriodicCallback(self.monitor.update, 500, io_loop=self.io_loop)
        self.periodic_callbacks["monitor"] = pc

        self._last_tick = time()
        pc = PeriodicCallback(
            self._measure_tick,
            parse_timedelta(dask.config.get("distributed.admin.tick.interval"),
                            default="ms") * 1000,
            io_loop=self.io_loop,
        )
        self.periodic_callbacks["tick"] = pc

        self.thread_id = 0

        def set_thread_ident():
            self.thread_id = threading.get_ident()

        self.io_loop.add_callback(set_thread_ident)

        self.__stopped = False

    async def finished(self):
        """ Wait until the server has finished """
        await self._event_finished.wait()

    def start_periodic_callbacks(self):
        """ Start Periodic Callbacks consistently

        This starts all PeriodicCallbacks stored in self.periodic_callbacks if
        they are not yet running.  It does this safely on the IOLoop.
        """
        self._last_tick = time()

        def start_pcs():
            for pc in self.periodic_callbacks.values():
                if not pc.is_running():
                    pc.start()

        self.io_loop.add_callback(start_pcs)

    def stop(self):
        if not self.__stopped:
            self.__stopped = True
            if self.listener is not None:
                # Delay closing the server socket until the next IO loop tick.
                # Otherwise race conditions can appear if an event handler
                # for an accept() call is already scheduled by the IO loop,
                # raising EBADF.
                # The demonstrator for this is Worker.terminate(), which
                # closes the server socket in response to an incoming message.
                # See https://github.com/tornadoweb/tornado/issues/2069
                self.io_loop.add_callback(self.listener.stop)

    def _measure_tick(self):
        now = time()
        diff = now - self._last_tick
        self._last_tick = now
        if diff > tick_maximum_delay:
            logger.info(
                "Event loop was unresponsive in %s for %.2fs.  "
                "This is often caused by long-running GIL-holding "
                "functions or moving large chunks of data. "
                "This can cause timeouts and instability.",
                type(self).__name__,
                diff,
            )
        if self.digests is not None:
            self.digests["tick-duration"].add(diff)

    def log_event(self, name, msg):
        msg["time"] = time()
        if isinstance(name, list):
            for n in name:
                self.events[n].append(msg)
                self.event_counts[n] += 1
        else:
            self.events[name].append(msg)
            self.event_counts[name] += 1

    @property
    def address(self):
        """
        The address this Server can be contacted on.
        """
        if not self._address:
            if self.listener is None:
                raise ValueError("cannot get address of non-running Server")
            self._address = self.listener.contact_address
        return self._address

    @property
    def listen_address(self):
        """
        The address this Server is listening on.  This may be a wildcard
        address such as `tcp://0.0.0.0:1234`.
        """
        if not self._listen_address:
            if self.listener is None:
                raise ValueError(
                    "cannot get listen address of non-running Server")
            self._listen_address = self.listener.listen_address
        return self._listen_address

    @property
    def port(self):
        """
        The port number this Server is listening on.

        This will raise ValueError if the Server is listening on a
        non-IP based protocol.
        """
        if not self._port:
            _, self._port = get_address_host_port(self.address)
        return self._port

    def identity(self, comm=None):
        return {"type": type(self).__name__, "id": self.id}

    def listen(self, port_or_addr=None, listen_args=None):
        if port_or_addr is None:
            port_or_addr = self.default_port
        if isinstance(port_or_addr, int):
            addr = unparse_host_port(self.default_ip, port_or_addr)
        elif isinstance(port_or_addr, tuple):
            addr = unparse_host_port(*port_or_addr)
        else:
            addr = port_or_addr
            assert isinstance(addr, str)
        self.listener = listen(
            addr,
            self.handle_comm,
            deserialize=self.deserialize,
            connection_args=listen_args,
        )
        self.listener.start()

    async def handle_comm(self, comm, shutting_down=shutting_down):
        """ Dispatch new communications to coroutine-handlers

        Handlers is a dictionary mapping operation names to functions or
        coroutines.

            {'get_data': get_data,
             'ping': pingpong}

        Coroutines should expect a single Comm object.
        """
        if self.__stopped:
            comm.abort()
            return
        address = comm.peer_address
        op = None

        logger.debug("Connection from %r to %s", address, type(self).__name__)
        self._comms[comm] = op
        try:
            while True:
                try:
                    msg = await comm.read()
                    logger.debug("Message from %r: %s", address, msg)
                except EnvironmentError as e:
                    if not shutting_down():
                        logger.debug(
                            "Lost connection to %r while reading message: %s."
                            " Last operation: %s",
                            address,
                            e,
                            op,
                        )
                    break
                except Exception as e:
                    logger.exception(e)
                    await comm.write(error_message(e, status="uncaught-error"))
                    continue
                if not isinstance(msg, dict):
                    raise TypeError(
                        "Bad message type.  Expected dict, got\n  " + str(msg))

                try:
                    op = msg.pop("op")
                except KeyError:
                    raise ValueError(
                        "Received unexpected message without 'op' key: " +
                        str(msg))
                if self.counters is not None:
                    self.counters["op"].add(op)
                self._comms[comm] = op
                serializers = msg.pop("serializers", None)
                close_desired = msg.pop("close", False)
                reply = msg.pop("reply", True)
                if op == "close":
                    if reply:
                        await comm.write("OK")
                    break

                result = None
                try:
                    if op in self.blocked_handlers:
                        _msg = (
                            "The '{op}' handler has been explicitly disallowed "
                            "in {obj}, possibly due to security concerns.")
                        exc = ValueError(
                            _msg.format(op=op, obj=type(self).__name__))
                        handler = raise_later(exc)
                    else:
                        handler = self.handlers[op]
                except KeyError:
                    logger.warning(
                        "No handler %s found in %s",
                        op,
                        type(self).__name__,
                        exc_info=True,
                    )
                else:
                    if serializers is not None and has_keyword(
                            handler, "serializers"):
                        msg["serializers"] = serializers  # add back in

                    logger.debug("Calling into handler %s", handler.__name__)
                    try:
                        result = handler(comm, **msg)
                        if hasattr(result, "__await__"):
                            result = asyncio.ensure_future(result)
                            self._ongoing_coroutines.add(result)
                            result = await result
                    except (CommClosedError, CancelledError) as e:
                        if self.status == "running":
                            logger.info("Lost connection to %r: %s", address,
                                        e)
                        break
                    except Exception as e:
                        logger.exception(e)
                        result = error_message(e, status="uncaught-error")

                if reply and result != "dont-reply":
                    try:
                        await comm.write(result, serializers=serializers)
                    except (EnvironmentError, TypeError) as e:
                        logger.debug(
                            "Lost connection to %r while sending result for op %r: %s",
                            address,
                            op,
                            e,
                        )
                        break
                msg = result = None
                if close_desired:
                    await comm.close()
                if comm.closed():
                    break

        finally:
            del self._comms[comm]
            if not shutting_down() and not comm.closed():
                try:
                    comm.abort()
                except Exception as e:
                    logger.error("Failed while closing connection to %r: %s",
                                 address, e)

    async def handle_stream(self, comm, extra=None, every_cycle=()):
        extra = extra or {}
        logger.info("Starting established connection")

        io_error = None
        closed = False
        try:
            while not closed:
                msgs = await comm.read()
                if not isinstance(msgs, (tuple, list)):
                    msgs = (msgs, )

                if not comm.closed():
                    for msg in msgs:
                        if msg == "OK":  # from close
                            break
                        op = msg.pop("op")
                        if op:
                            if op == "close-stream":
                                closed = True
                                break
                            handler = self.stream_handlers[op]
                            if is_coroutine_function(handler):
                                self.loop.add_callback(handler,
                                                       **merge(extra, msg))
                            else:
                                handler(**merge(extra, msg))
                        else:
                            logger.error("odd message %s", msg)
                    await gen.sleep(0)

                for func in every_cycle:
                    func()

        except (CommClosedError, EnvironmentError) as e:
            io_error = e
        except Exception as e:
            logger.exception(e)
            if LOG_PDB:
                import pdb

                pdb.set_trace()
            raise
        finally:
            await comm.close()
            assert comm.closed()

    @gen.coroutine
    def close(self):
        for pc in self.periodic_callbacks.values():
            pc.stop()
        if self.listener:
            self.listener.stop()
        for i in range(20):  # let comms close naturally for a second
            if not self._comms:
                break
            else:
                yield gen.sleep(0.05)
        yield [comm.close() for comm in self._comms]  # then forcefully close
        for cb in self._ongoing_coroutines:
            cb.cancel()
        for i in range(10):
            if all(c.cancelled() for c in self._ongoing_coroutines):
                break
            else:
                yield gen.sleep(0.01)

        self._event_finished.set()
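
The dispatch contract in handle_comm above is worth restating: each request is a dict whose "op" key selects a handler, with optional "reply", "close", and "serializers" control keys, and awaitable results are awaited before the reply is written back. The following is a minimal, self-contained sketch of that dispatch idea, with a plain list of dicts standing in for the comm; HANDLERS, pingpong, and echo are illustrative names, not part of the original class.

import asyncio

# Illustrative handlers, mirroring the {'get_data': ..., 'ping': ...}
# mapping described in the handle_comm docstring.
async def pingpong(**kwargs):
    return "pong"

def echo(x=None, **kwargs):
    return x

HANDLERS = {"ping": pingpong, "echo": echo}

async def dispatch(messages):
    """Dispatch dict messages to handlers, as handle_comm does."""
    for msg in messages:
        op = msg.pop("op")             # every message must carry an 'op' key
        reply = msg.pop("reply", True)
        result = HANDLERS[op](**msg)   # remaining keys become kwargs
        if hasattr(result, "__await__"):
            result = await result      # coroutine handlers are awaited
        if reply:
            print(op, "->", result)

asyncio.run(dispatch([{"op": "ping"}, {"op": "echo", "x": 42}]))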
Code example #55
File: queues.py Project: tantona/tornado
class Queue(object):
    """Coordinate producer and consumer coroutines.

    If maxsize is 0 (the default) the queue size is unbounded.
    """
    def __init__(self, maxsize=0):
        if maxsize is None:
            raise TypeError("maxsize can't be None")

        if maxsize < 0:
            raise ValueError("maxsize can't be negative")

        self._maxsize = maxsize
        self._init()
        self._getters = collections.deque([])  # Futures.
        self._putters = collections.deque([])  # Pairs of (item, Future).
        self._unfinished_tasks = 0
        self._finished = Event()
        self._finished.set()

    @property
    def maxsize(self):
        """Number of items allowed in the queue."""
        return self._maxsize

    def qsize(self):
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self):
        return not self._queue

    def full(self):
        if self.maxsize == 0:
            return False
        else:
            return self.qsize() >= self.maxsize

    def put(self, item, timeout=None):
        """Put an item into the queue, perhaps waiting until there is room.

        Returns a Future, which raises `tornado.gen.TimeoutError` after a
        timeout.
        """
        try:
            self.put_nowait(item)
        except QueueFull:
            future = Future()
            self._putters.append((item, future))
            _set_timeout(future, timeout)
            return future
        else:
            return gen._null_future

    def put_nowait(self, item):
        """Put an item into the queue without blocking.

        If no free slot is immediately available, raise `QueueFull`.
        """
        self._consume_expired()
        if self._getters:
            assert self.empty(), "queue non-empty, why are getters waiting?"
            getter = self._getters.popleft()
            self.__put_internal(item)
            getter.set_result(self._get())
        elif self.full():
            raise QueueFull
        else:
            self.__put_internal(item)

    def get(self, timeout=None):
        """Remove and return an item from the queue.

        Returns a Future which resolves once an item is available, or raises
        `tornado.gen.TimeoutError` after a timeout.
        """
        future = Future()
        try:
            future.set_result(self.get_nowait())
        except QueueEmpty:
            self._getters.append(future)
            _set_timeout(future, timeout)
        return future

    def get_nowait(self):
        """Remove and return an item from the queue without blocking.

        Return an item if one is immediately available, else raise
        `QueueEmpty`.
        """
        self._consume_expired()
        if self._putters:
            assert self.full(), "queue not full, why are putters waiting?"
            item, putter = self._putters.popleft()
            self.__put_internal(item)
            putter.set_result(None)
            return self._get()
        elif self.qsize():
            return self._get()
        else:
            raise QueueEmpty

    def task_done(self):
        """Indicate that a formerly enqueued task is complete.

        Used by queue consumers. For each `.get` used to fetch a task, a
        subsequent call to `.task_done` tells the queue that the processing
        on the task is complete.

        If a `.join` is blocking, it resumes when all items have been
        processed; that is, when every `.put` is matched by a `.task_done`.

        Raises `ValueError` if called more times than `.put`.
        """
        if self._unfinished_tasks <= 0:
            raise ValueError('task_done() called too many times')
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    def join(self, timeout=None):
        """Block until all items in the queue are processed.

        Returns a Future, which raises `tornado.gen.TimeoutError` after a
        timeout.
        """
        return self._finished.wait(timeout)

    # These three are overridable in subclasses.
    def _init(self):
        self._queue = collections.deque()

    def _get(self):
        return self._queue.popleft()

    def _put(self, item):
        self._queue.append(item)

    # End of the overridable methods.

    def __put_internal(self, item):
        self._unfinished_tasks += 1
        self._finished.clear()
        self._put(item)

    def _consume_expired(self):
        # Remove timed-out waiters.
        while self._putters and self._putters[0][1].done():
            self._putters.popleft()

        while self._getters and self._getters[0].done():
            self._getters.popleft()

    def __repr__(self):
        return '<%s at %s %s>' % (type(self).__name__, hex(
            id(self)), self._format())

    def __str__(self):
        return '<%s %s>' % (type(self).__name__, self._format())

    def _format(self):
        result = 'maxsize=%r' % (self.maxsize, )
        if getattr(self, '_queue', None):
            result += ' queue=%r' % self._queue
        if self._getters:
            result += ' getters[%s]' % len(self._getters)
        if self._putters:
            result += ' putters[%s]' % len(self._putters)
        if self._unfinished_tasks:
            result += ' tasks=%s' % self._unfinished_tasks
        return result
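
A short usage sketch for this pre-async/await Queue API, written against the public tornado.queues.Queue that this example mirrors; the producer/consumer split and the maxsize=2 value are illustrative, not part of the original file.

from tornado import gen
from tornado.ioloop import IOLoop
from tornado.queues import Queue

q = Queue(maxsize=2)

@gen.coroutine
def consumer():
    while True:
        item = yield q.get()      # waits while the queue is empty
        try:
            print('consumed %s' % item)
        finally:
            q.task_done()         # pair every get() with a task_done()

@gen.coroutine
def producer():
    for item in range(4):
        yield q.put(item)         # waits while the queue is full

@gen.coroutine
def main():
    IOLoop.current().spawn_callback(consumer)  # consumer never finishes
    yield producer()
    yield q.join()                # resolves once every item is task_done()'d

IOLoop.current().run_sync(main)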
Code example #56
class Queue(object):
    """Coordinate producer and consumer coroutines.

    If maxsize is 0 (the default) the queue size is unbounded.

    .. testcode::

        from tornado import gen
        from tornado.ioloop import IOLoop
        from tornado.queues import Queue

        q = Queue(maxsize=2)

        async def consumer():
            async for item in q:
                try:
                    print('Doing work on %s' % item)
                    await gen.sleep(0.01)
                finally:
                    q.task_done()

        async def producer():
            for item in range(5):
                await q.put(item)
                print('Put %s' % item)

        async def main():
            # Start consumer without waiting (since it never finishes).
            IOLoop.current().spawn_callback(consumer)
            await producer()     # Wait for producer to put all tasks.
            await q.join()       # Wait for consumer to finish all tasks.
            print('Done')

        IOLoop.current().run_sync(main)

    .. testoutput::

        Put 0
        Put 1
        Doing work on 0
        Put 2
        Doing work on 1
        Put 3
        Doing work on 2
        Put 4
        Doing work on 3
        Doing work on 4
        Done


    In versions of Python without native coroutines (before 3.5),
    ``consumer()`` could be written as::

        @gen.coroutine
        def consumer():
            while True:
                item = yield q.get()
                try:
                    print('Doing work on %s' % item)
                    yield gen.sleep(0.01)
                finally:
                    q.task_done()

    .. versionchanged:: 4.3
       Added ``async for`` support in Python 3.5.

    """
    def __init__(self, maxsize=0):
        if maxsize is None:
            raise TypeError("maxsize can't be None")

        if maxsize < 0:
            raise ValueError("maxsize can't be negative")

        self._maxsize = maxsize
        self._init()
        self._getters = collections.deque([])  # Futures.
        self._putters = collections.deque([])  # Pairs of (item, Future).
        self._unfinished_tasks = 0
        self._finished = Event()
        self._finished.set()

    @property
    def maxsize(self):
        """Number of items allowed in the queue."""
        return self._maxsize

    def qsize(self):
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self):
        return not self._queue

    def full(self):
        if self.maxsize == 0:
            return False
        else:
            return self.qsize() >= self.maxsize

    def put(self, item, timeout=None):
        """Put an item into the queue, perhaps waiting until there is room.

        Returns a Future, which raises `tornado.util.TimeoutError` after a
        timeout.

        ``timeout`` may be a number denoting a time (on the same
        scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
        `datetime.timedelta` object for a deadline relative to the
        current time.
        """
        future = Future()
        try:
            self.put_nowait(item)
        except QueueFull:
            self._putters.append((item, future))
            _set_timeout(future, timeout)
        else:
            future.set_result(None)
        return future

    def put_nowait(self, item):
        """Put an item into the queue without blocking.

        If no free slot is immediately available, raise `QueueFull`.
        """
        self._consume_expired()
        if self._getters:
            assert self.empty(), "queue non-empty, why are getters waiting?"
            getter = self._getters.popleft()
            self.__put_internal(item)
            future_set_result_unless_cancelled(getter, self._get())
        elif self.full():
            raise QueueFull
        else:
            self.__put_internal(item)

    def get(self, timeout=None):
        """Remove and return an item from the queue.

        Returns a Future which resolves once an item is available, or raises
        `tornado.util.TimeoutError` after a timeout.

        ``timeout`` may be a number denoting a time (on the same
        scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
        `datetime.timedelta` object for a deadline relative to the
        current time.
        """
        future = Future()
        try:
            future.set_result(self.get_nowait())
        except QueueEmpty:
            self._getters.append(future)
            _set_timeout(future, timeout)
        return future

    def get_nowait(self):
        """Remove and return an item from the queue without blocking.

        Return an item if one is immediately available, else raise
        `QueueEmpty`.
        """
        self._consume_expired()
        if self._putters:
            assert self.full(), "queue not full, why are putters waiting?"
            item, putter = self._putters.popleft()
            self.__put_internal(item)
            future_set_result_unless_cancelled(putter, None)
            return self._get()
        elif self.qsize():
            return self._get()
        else:
            raise QueueEmpty

    def task_done(self):
        """Indicate that a formerly enqueued task is complete.

        Used by queue consumers. For each `.get` used to fetch a task, a
        subsequent call to `.task_done` tells the queue that the processing
        on the task is complete.

        If a `.join` is blocking, it resumes when all items have been
        processed; that is, when every `.put` is matched by a `.task_done`.

        Raises `ValueError` if called more times than `.put`.
        """
        if self._unfinished_tasks <= 0:
            raise ValueError('task_done() called too many times')
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    def join(self, timeout=None):
        """Block until all items in the queue are processed.

        Returns a Future, which raises `tornado.util.TimeoutError` after a
        timeout.
        """
        return self._finished.wait(timeout)

    def __aiter__(self):
        return _QueueIterator(self)

    # These three are overridable in subclasses.
    def _init(self):
        self._queue = collections.deque()

    def _get(self):
        return self._queue.popleft()

    def _put(self, item):
        self._queue.append(item)

    # End of the overridable methods.

    def __put_internal(self, item):
        self._unfinished_tasks += 1
        self._finished.clear()
        self._put(item)

    def _consume_expired(self):
        # Remove timed-out waiters.
        while self._putters and self._putters[0][1].done():
            self._putters.popleft()

        while self._getters and self._getters[0].done():
            self._getters.popleft()

    def __repr__(self):
        return '<%s at %s %s>' % (type(self).__name__, hex(
            id(self)), self._format())

    def __str__(self):
        return '<%s %s>' % (type(self).__name__, self._format())

    def _format(self):
        result = 'maxsize=%r' % (self.maxsize, )
        if getattr(self, '_queue', None):
            result += ' queue=%r' % self._queue
        if self._getters:
            result += ' getters[%s]' % len(self._getters)
        if self._putters:
            result += ' putters[%s]' % len(self._putters)
        if self._unfinished_tasks:
            result += ' tasks=%s' % self._unfinished_tasks
        return result
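
The timeout path described in the get/put docstrings can be exercised directly. A small sketch, assuming Tornado 5+ where TimeoutError lives in tornado.util; the 0.1-second deadline is an arbitrary choice for illustration.

import datetime

from tornado.ioloop import IOLoop
from tornado.queues import Queue
from tornado.util import TimeoutError

async def main():
    q = Queue()
    try:
        # No producer, so get() fails once the relative deadline passes.
        await q.get(timeout=datetime.timedelta(seconds=0.1))
    except TimeoutError:
        print('get() timed out')

IOLoop.current().run_sync(main)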
Code example #57
class ProjectGroomer(object):
  """ Cleans up expired transactions for a project. """
  def __init__(self, project_id, coordinator, zk_client, db_access,
               thread_pool):
    """ Creates a new ProjectGroomer.

    Args:
      project_id: A string specifying a project ID.
      coordinator: A GroomingCoordinator.
      zk_client: A KazooClient.
      db_access: A DatastoreProxy.
      thread_pool: A ThreadPoolExecutor.
    """
    self.project_id = project_id

    self._coordinator = coordinator
    self._zk_client = zk_client
    self._tornado_zk = TornadoKazoo(self._zk_client)
    self._db_access = db_access
    self._thread_pool = thread_pool
    self._project_node = '/appscale/apps/{}'.format(self.project_id)
    self._containers = []
    self._inactive_containers = set()
    self._batch_resolver = BatchResolver(self.project_id, self._db_access)

    self._zk_client.ensure_path(self._project_node)
    self._zk_client.ChildrenWatch(self._project_node, self._update_containers)

    self._txid_manual_offset = 0
    self._offset_node = '/'.join([self._project_node, OFFSET_NODE])
    self._zk_client.DataWatch(self._offset_node, self._update_offset)

    self._stop_event = AsyncEvent()
    self._stopped_event = AsyncEvent()

    # Keeps track of cleanup results for each round of grooming.
    self._txids_cleaned = 0
    self._oldest_valid_tx_time = None

    self._worker_queue = AsyncQueue(maxsize=MAX_CONCURRENCY)
    for _ in range(MAX_CONCURRENCY):
      IOLoop.current().spawn_callback(self._worker)

    IOLoop.current().spawn_callback(self.start)

  @gen.coroutine
  def start(self):
    """ Starts the grooming process until the stop event is set. """
    logger.info('Grooming {}'.format(self.project_id))
    while True:
      if self._stop_event.is_set():
        break

      try:
        yield self._groom_project()
      except Exception:
        # Prevent the grooming loop from stopping if an error is encountered.
        logger.exception(
          'Unexpected error while grooming {}'.format(self.project_id))
        yield gen.sleep(MAX_TX_DURATION)

    self._stopped_event.set()

  @gen.coroutine
  def stop(self):
    """ Stops the grooming process. """
    logger.info('Stopping grooming process for {}'.format(self.project_id))
    self._stop_event.set()
    yield self._stopped_event.wait()

  @gen.coroutine
  def _worker(self):
    """ Processes items in the worker queue. """
    while True:
      tx_path, composite_indexes = yield self._worker_queue.get()
      try:
        tx_time = yield self._resolve_txid(tx_path, composite_indexes)
        if tx_time is None:
          self._txids_cleaned += 1

        if tx_time is not None and tx_time < self._oldest_valid_tx_time:
          self._oldest_valid_tx_time = tx_time
      finally:
        self._worker_queue.task_done()

  def _update_offset(self, new_offset, _):
    """ Watches for updates to the manual offset node.

    Args:
      new_offset: A string specifying the new manual offset.
    """
    self._txid_manual_offset = int(new_offset or 0)

  def _update_containers(self, nodes):
    """ Updates the list of active txid containers.

    Args:
      nodes: A list of strings specifying ZooKeeper nodes.
    """
    counters = [int(node[len(CONTAINER_PREFIX):] or 1)
                for node in nodes if node.startswith(CONTAINER_PREFIX)
                and node not in self._inactive_containers]
    counters.sort()

    containers = [CONTAINER_PREFIX + str(counter) for counter in counters]
    if containers and containers[0] == '{}1'.format(CONTAINER_PREFIX):
      containers[0] = CONTAINER_PREFIX

    self._containers = containers

  @gen.coroutine
  def _groom_project(self):
    """ Runs the grooming process. """
    index = self._coordinator.index
    worker_count = self._coordinator.total_workers

    oldest_valid_tx_time = yield self._fetch_and_clean(index, worker_count)

    # Wait until there's a reasonable chance that some transactions have
    # timed out.
    next_timeout_eta = oldest_valid_tx_time + MAX_TX_DURATION

    # The oldest ignored transaction should still be valid, but ensure that
    # the timeout is not negative.
    next_timeout = max(0, next_timeout_eta - time.time())
    time_to_wait = datetime.timedelta(
      seconds=next_timeout + (MAX_TX_DURATION / 2))

    # Allow the wait to be cut short when a project is removed.
    try:
      yield self._stop_event.wait(timeout=time_to_wait)
    except gen.TimeoutError:
      raise gen.Return()

  @gen.coroutine
  def _remove_path(self, tx_path):
    """ Removes a ZooKeeper node.

    Args:
      tx_path: A string specifying the path to delete.
    """
    try:
      yield self._tornado_zk.delete(tx_path)
    except NoNodeError:
      pass
    except NotEmptyError:
      yield self._thread_pool.submit(self._zk_client.delete, tx_path,
                                     recursive=True)

  @gen.coroutine
  def _resolve_txid(self, tx_path, composite_indexes):
    """ Cleans up a transaction if it has expired.

    Args:
      tx_path: A string specifying the location of the ZooKeeper node.
      composite_indexes: A list of CompositeIndex objects.
    Returns:
      The transaction start time if the transaction is still valid; None if
      it is invalid, in which case this method also deletes it.
    """
    tx_data = yield self._tornado_zk.get(tx_path)
    tx_time = float(tx_data[0])

    _, container, tx_node = tx_path.rsplit('/', 2)
    tx_node_id = int(tx_node.lstrip(COUNTER_NODE_PREFIX))
    container_count = int(container[len(CONTAINER_PREFIX):] or 1)
    if tx_node_id < 0:
      yield self._remove_path(tx_path)
      raise gen.Return()

    container_size = MAX_SEQUENCE_COUNTER + 1
    automatic_offset = (container_count - 1) * container_size
    txid = self._txid_manual_offset + automatic_offset + tx_node_id

    if txid < 1:
      yield self._remove_path(tx_path)
      raise gen.Return()

    # If the transaction is still valid, return the time it was created.
    if tx_time + MAX_TX_DURATION >= time.time():
      raise gen.Return(tx_time)

    yield self._batch_resolver.resolve(txid, composite_indexes)
    yield self._remove_path(tx_path)
    yield self._batch_resolver.cleanup(txid)

  @gen.coroutine
  def _fetch_and_clean(self, worker_index, worker_count):
    """ Cleans up expired transactions.

    Args:
      worker_index: An integer specifying this worker's index.
      worker_count: An integer specifying the number of total workers.
    Returns:
      A float specifying the time of the oldest valid transaction as a unix
      timestamp.
    """
    self._txids_cleaned = 0
    self._oldest_valid_tx_time = time.time()

    children = []
    for index, container in enumerate(self._containers):
      container_path = '/'.join([self._project_node, container])
      new_children = yield self._tornado_zk.get_children(container_path)

      if not new_children and index < len(self._containers) - 1:
        self._inactive_containers.add(container)

      children.extend(['/'.join([container_path, node])
                       for node in new_children])

    logger.debug(
      'Found {} transaction IDs for {}'.format(len(children), self.project_id))

    if not children:
      raise gen.Return(self._oldest_valid_tx_time)

    # Refresh these each time so that the indexes are fresh.
    encoded_indexes = yield self._thread_pool.submit(
      self._db_access.get_indices, self.project_id)
    composite_indexes = [CompositeIndex(index) for index in encoded_indexes]

    for tx_path in children:
      tx_node_id = int(tx_path.split('/')[-1].lstrip(COUNTER_NODE_PREFIX))
      # Only resolve transactions that this worker has been assigned.
      if tx_node_id % worker_count != worker_index:
        continue

      yield self._worker_queue.put((tx_path, composite_indexes))

    yield self._worker_queue.join()

    if self._txids_cleaned > 0:
      logger.info('Cleaned up {} expired txids for {}'.format(
        self._txids_cleaned, self.project_id))

    raise gen.Return(self._oldest_valid_tx_time)
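
The _worker/_fetch_and_clean pair above follows a common Tornado pattern: spawn a fixed pool of queue consumers once, feed them a batch with put(), then join() to wait for the batch to drain. Below is a minimal, self-contained sketch of that pattern under illustrative names (handle_item stands in for _resolve_txid, and MAX_CONCURRENCY = 4 is an assumed value, since the original constant is defined elsewhere).

from tornado import gen
from tornado.ioloop import IOLoop
from tornado.queues import Queue

MAX_CONCURRENCY = 4               # assumed value for this sketch
work_queue = Queue(maxsize=MAX_CONCURRENCY)

@gen.coroutine
def handle_item(item):
    yield gen.sleep(0.01)         # stands in for _resolve_txid()
    print('processed %s' % item)

@gen.coroutine
def worker():
    """Long-lived consumer, like ProjectGroomer._worker."""
    while True:
        item = yield work_queue.get()
        try:
            yield handle_item(item)
        finally:
            work_queue.task_done()   # always mark the item done

@gen.coroutine
def run_batch(items):
    """Feed one batch and wait for it, like _fetch_and_clean."""
    for item in items:
        yield work_queue.put(item)   # maxsize applies backpressure
    yield work_queue.join()          # resumes when the batch is drained

for _ in range(MAX_CONCURRENCY):
    IOLoop.current().spawn_callback(worker)
IOLoop.current().run_sync(lambda: run_batch(range(10)))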