Example #1
 def __init__(self):
     self.connected = False
     self.connected_event = Event()
     self.disconnected_event = Event()
     self.presence_queue = Queue()
     self.message_queue = Queue()
     self.error_queue = Queue()
Example #2
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
    def setUp(self):
        self.cleanup_event = Event()
        test = self

        # Dummy Resolver subclass that never finishes.
        class BadResolver(Resolver):
            @gen.coroutine
            def resolve(self, *args, **kwargs):
                yield test.cleanup_event.wait()
                # Return something valid so the test doesn't raise during cleanup.
                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]

        super(ResolveTimeoutTestCase, self).setUp()
        self.http_client = SimpleAsyncHTTPClient(resolver=BadResolver())

    def get_app(self):
        return Application([url("/hello", HelloWorldHandler)])

    def test_resolve_timeout(self):
        with self.assertRaises(HTTPTimeoutError):
            self.fetch("/hello", connect_timeout=0.1, raise_error=True)

        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()
        self.io_loop.run_sync(lambda: gen.sleep(0))
Example #3
    def __init__(self, pubnub_instance):

        subscription_manager = self

        self._message_queue = Queue()
        self._consumer_event = Event()
        self._cancellation_event = Event()
        self._subscription_lock = Semaphore(1)
        # self._current_request_key_object = None
        self._heartbeat_periodic_callback = None
        self._reconnection_manager = TornadoReconnectionManager(pubnub_instance)

        super(TornadoSubscriptionManager, self).__init__(pubnub_instance)
        self._start_worker()

        class TornadoReconnectionCallback(ReconnectionCallback):
            def on_reconnect(self):
                subscription_manager.reconnect()

                pn_status = PNStatus()
                pn_status.category = PNStatusCategory.PNReconnectedCategory
                pn_status.error = False

                subscription_manager._subscription_status_announced = True
                subscription_manager._listener_manager.announce_status(pn_status)

        self._reconnection_listener = TornadoReconnectionCallback()
        self._reconnection_manager.set_reconnection_listener(self._reconnection_listener)
Example #4
    def test_connect_timeout(self):
        timeout = 0.1

        cleanup_event = Event()
        test = self

        class TimeoutResolver(Resolver):
            async def resolve(self, *args, **kwargs):
                await cleanup_event.wait()
                # Return something valid so the test doesn't raise during shutdown.
                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]

        with closing(self.create_client(resolver=TimeoutResolver())) as client:
            with self.assertRaises(HTTPTimeoutError):
                yield client.fetch(
                    self.get_url("/hello"),
                    connect_timeout=timeout,
                    request_timeout=3600,
                    raise_error=True,
                )

        # Let the hanging coroutine clean up after itself. We need to
        # wait more than a single IOLoop iteration for the SSL case,
        # which logs errors on unexpected EOF.
        cleanup_event.set()
        yield gen.sleep(0.2)
Example #5
 def get(self):
     logging.debug("queuing trigger")
     event = Event()
     self.queue.append(event.set)
     if self.get_argument("wake", "true") == "true":
         self.wake_callback()
     yield event.wait()
Example #6
 def get(self):
     logging.debug("queuing trigger")
     self.queue.append(self.finish)
     if self.get_argument("wake", "true") == "true":
         self.wake_callback()
     never_finish = Event()
     yield never_finish.wait()
Example #7
    def test_http10_no_content_length(self):
        # Regression test for a bug in which can_keep_alive would crash
        # for an HTTP/1.0 (not 1.1) response with no content-length.
        conn = HTTP1Connection(self.client_stream, True)
        self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello")
        self.server_stream.close()

        event = Event()
        test = self
        body = []

        class Delegate(HTTPMessageDelegate):
            def headers_received(self, start_line, headers):
                test.code = start_line.code

            def data_received(self, data):
                body.append(data)

            def finish(self):
                event.set()

        yield conn.read_response(Delegate())
        yield event.wait()
        self.assertEqual(self.code, 200)
        self.assertEqual(b''.join(body), b'hello')
Example #8
    def test_read_until_regex_max_bytes(self):
        rs, ws = yield self.make_iostream_pair()
        closed = Event()
        rs.set_close_callback(closed.set)
        try:
            # Extra room under the limit
            fut = rs.read_until_regex(b"def", max_bytes=50)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")

            # Just enough space
            fut = rs.read_until_regex(b"def", max_bytes=6)
            ws.write(b"abcdef")
            data = yield fut
            self.assertEqual(data, b"abcdef")

            # Not enough space; by the time we find out, all we can do is
            # log a warning and close the connection.
            with ExpectLog(gen_log, "Unsatisfiable read"):
                rs.read_until_regex(b"def", max_bytes=5)
                ws.write(b"123456")
                yield closed.wait()
        finally:
            ws.close()
            rs.close()
Example #9
class QueueDriver:
    def __init__(self,**settings):
        self.settings = settings
        self._finished = Event()
        self._getters = collections.deque([])  # Futures.
        self._putters = collections.deque([])
        self.initialize(**settings)

    def initialize(self,**settings):
        pass

    def over(self):
        self._finished.set()

    def save(self):
        raise NotImplementedError()

    def get(self):
        raise NotImplementedError()

    def put(self):
        raise NotImplementedError()

    def join(self,timeout=None):
        return self._finished.wait(timeout)
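Since Event.wait() accepts an optional timeout, a join() like the one above can be awaited with a deadline. A minimal sketch, assuming a QueueDriver instance named driver; the helper name wait_for_driver and the five-second timeout are illustrative, not part of the example:

import datetime

from tornado import gen


@gen.coroutine
def wait_for_driver(driver):
    try:
        # join() returns the underlying Event's wait() future; passing a
        # timedelta makes the timeout relative to now.
        yield driver.join(timeout=datetime.timedelta(seconds=5))
    except gen.TimeoutError:
        # over() was never called within five seconds.
        pass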
Example #10
class GameHandler(WebSocketHandler):
    def __init__(self, *args, **kwargs):
        super(GameHandler, self).__init__(*args, **kwargs)
        self.game = None
        self.player = None
        self.answered = Event()

    def open(self):
        pass

    def auth(self, data):
        self.game = GameLoader.load(data["game"])
        self.player = data["player"]
        self.game.handlers[self.player] = self
        self.game.introduce(self.player)

    def on_message(self, message):
        data = Dumper.load(message)
        if data["action"] == "auth":
            self.auth(data)
        if self.game is None:
            self.close()
            return
        if data["action"] == "move":
            self.game.take_action(self.player, data)
        if data["action"] == "answer":
            self.answered.set()

    def on_close(self):
        print("WebSocket closed")

    def check_origin(self, origin):
        return True
Example #11
 def start_night(self):
     self.look_own_card_done = Event()
     self.werewolves_wake_up_done = Event()
     yield [
         self.look_own_card(),
         self.werewolves_wake_up(),
         self.seer_wake_up(),
     ]
Example #12
    class Waiter(object):
        def __init__(self):
            self.event = Event()

        @gen.coroutine
        def set(self):
            self.event.set()

        @gen.coroutine
        def wait(self):
            yield self.event.wait()
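The Waiter above just wraps tornado.locks.Event in coroutine-style set/wait methods. A minimal, self-contained sketch of the same wait/set pattern driven on an IOLoop; the function name waiter_demo and the 0.1-second delay are illustrative, not from the example:

from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Event


@gen.coroutine
def waiter_demo():
    event = Event()
    # Schedule event.set() to run a little later; wait() resumes once it fires.
    IOLoop.current().call_later(0.1, event.set)
    yield event.wait()
    print("event was set")


if __name__ == "__main__":
    IOLoop.current().run_sync(waiter_demo)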
Example #13
    def _start(self):
        if self.scheduler.status != 'running':
            yield self.scheduler._sync_center()
            self.scheduler.start()

        start_event = Event()
        self.coroutines = [
                self.scheduler.handle_queues(self.scheduler_queue, self.report_queue),
                self.report(start_event)]

        _global_executor[0] = self
        yield start_event.wait()
        logger.debug("Started scheduling coroutines. Synchronized")
Example #14
    def test_idle_after_use(self):
        stream = yield self.connect()
        event = Event()
        stream.set_close_callback(event.set)

        # Use the connection twice to make sure keep-alives are working
        for i in range(2):
            stream.write(b"GET / HTTP/1.1\r\n\r\n")
            yield stream.read_until(b"\r\n\r\n")
            data = yield stream.read_bytes(11)
            self.assertEqual(data, b"Hello world")

        # Now let the timeout trigger and close the connection.
        yield event.wait()
Example #15
 def __init__(self, logger, loop, sqs_client,
              metric_prefix='emitter'):
     self.emitter = sqs_client
     self.logger = logger
     self.loop = loop
     self.metric_prefix = metric_prefix
     self.output_error = Event()
     self.state = RUNNING
     self.sender_tag = 'sender:%s.%s' % (self.__class__.__module__,
                                         self.__class__.__name__)
     self._send_queue = Queue()
     self._should_flush_queue = Event()
     self._flush_handle = None
     self.loop.spawn_callback(self._onSend)
Example #16
 def test_read_until_regex_max_bytes_ignores_extra(self):
     rs, ws = yield self.make_iostream_pair()
     closed = Event()
     rs.set_close_callback(closed.set)
     try:
         # Even though the data that matches arrives in the same packet that
         # puts us over the limit, we fail the request because it was not
         # found within the limit.
         ws.write(b"abcdef")
         with ExpectLog(gen_log, "Unsatisfiable read"):
             rs.read_until_regex(b"def", max_bytes=5)
             yield closed.wait()
     finally:
         ws.close()
         rs.close()
Example #17
    def test_prepare_curl_callback_stack_context(self):
        exc_info = []
        error_event = Event()

        def error_handler(typ, value, tb):
            exc_info.append((typ, value, tb))
            error_event.set()
            return True

        with ExceptionStackContext(error_handler):
            request = HTTPRequest(self.get_url('/custom_reason'),
                                  prepare_curl_callback=lambda curl: 1 / 0)
        yield [error_event.wait(), self.http_client.fetch(request)]
        self.assertEqual(1, len(exc_info))
        self.assertIs(exc_info[0][0], ZeroDivisionError)
Example #18
 def __init__(self, limit=512):
     self.open = 0
     self.active = 0
     self.limit = limit
     self.available = defaultdict(set)
     self.occupied = defaultdict(set)
     self.event = Event()
Example #19
  def __init__(self, client, path, identifier=None):
    """ Creates an AsyncKazooLock.

    Args:
      client: A KazooClient.
      path: The lock path to use.
      identifier: The name to use for this lock contender. This can be useful
        for querying to see who the current lock contenders are.
    """
    self.client = client
    self.tornado_kazoo = TornadoKazoo(client)
    self.path = path

    # some data is written to the node. this can be queried via
    # contenders() to see who is contending for the lock
    self.data = str(identifier or "").encode('utf-8')
    self.node = None

    self.wake_event = AsyncEvent()

    # props to Netflix Curator for this trick. It is possible for our
    # create request to succeed on the server, but for a failure to
    # prevent us from getting back the full path name. We prefix our
    # lock name with a uuid and can check for its presence on retry.
    self.prefix = uuid.uuid4().hex + self._NODE_NAME
    self.create_path = self.path + "/" + self.prefix

    self.create_tried = False
    self.is_acquired = False
    self.assured_path = False
    self.cancelled = False
    self._retry = AsyncKazooRetry(max_tries=-1)
    self._lock = AsyncLock()
Example #20
    def asyncSetUp(self):
        listener, port = bind_unused_port()
        event = Event()

        def accept_callback(conn, addr):
            self.server_stream = IOStream(conn)
            self.addCleanup(self.server_stream.close)
            event.set()

        add_accept_handler(listener, accept_callback)
        self.client_stream = IOStream(socket.socket())
        self.addCleanup(self.client_stream.close)
        yield [self.client_stream.connect(('127.0.0.1', port)),
               event.wait()]
        self.io_loop.remove_handler(listener)
        listener.close()
Example #21
 def test_open_coroutine(self):
     self.message_sent = Event()
     ws = yield self.ws_connect("/open_coroutine")
     yield ws.write_message("hello")
     self.message_sent.set()
     res = yield ws.read_message()
     self.assertEqual(res, "ok")
Example #22
    def __init__(self, check_time, fx_correlator_object):

        self.instrument = fx_correlator_object
        self.hosts = self.instrument.fhosts + self.instrument.xhosts
        self.selected_host = None
        self.host_index = 0
        self.num_hosts = len(self.hosts)
        self.num_fhosts = len(self.instrument.fhosts)
        self.num_xhosts = len(self.instrument.xhosts)
        # self.num_bhosts = len(self.instrument.bhosts)

        # check config file if bhosts or xhosts

        if check_time == -1:
            self.check_time = float(self.instrument.configd['FxCorrelator']['monitor_loop_time'])
        else:
            self.check_time = check_time

        # set up periodic engine monitoring
        self.instrument_monitoring_loop_enabled = IOLoopEvent()
        self.instrument_monitoring_loop_enabled.clear()
        self.instrument_monitoring_loop_cb = None

        self.f_eng_board_monitoring_dict_prev = {}
        self.x_eng_board_monitoring_dict_prev = {}
        self.b_eng_board_monitoring_dict_prev = {}

        self.disabled_fhosts = []
        self.disabled_xhosts = []
        self.disabled_bhosts = []

        # some other useful bits of info
        self.n_chans = self.instrument.n_chans
        self.chans_per_xhost = self.n_chans / self.num_xhosts
Example #23
  def __init__(self, zk_client, datastore_access, perform_admin=False):
    """ Creates a new IndexManager.

    Args:
      zk_client: A kazoo.client.KazooClient object.
      datastore_access: A DatastoreDistributed object.
      perform_admin: A boolean specifying whether or not to perform admin
        operations.
    """
    self.projects = {}
    self._wake_event = AsyncEvent()
    self._zk_client = zk_client
    self.admin_lock = AsyncKazooLock(self._zk_client, self.ADMIN_LOCK_NODE)

    # TODO: Refactor so that this dependency is not needed.
    self._ds_access = datastore_access

    self._zk_client.ensure_path('/appscale/projects')
    self._zk_client.ChildrenWatch('/appscale/projects', self._update_projects)

    # Since this manager can be used synchronously, ensure that the projects
    # are populated for this IOLoop iteration.
    project_ids = self._zk_client.get_children('/appscale/projects')
    self._update_projects_sync(project_ids)

    if perform_admin:
      IOLoop.current().spawn_callback(self._contend_for_admin_lock)
Example #24
  def __init__(self, project_id, zk_client, index_manager, datastore_access):
    """ Creates a new ProjectIndexManager.

    Args:
      project_id: A string specifying a project ID.
      zk_client: A KazooClient.
      index_manager: An IndexManager used for checking lock status.
      datastore_access: A DatastoreDistributed object.
    """
    self.project_id = project_id
    self.indexes_node = '/appscale/projects/{}/indexes'.format(self.project_id)
    self.active = True
    self.update_event = AsyncEvent()

    self._creation_times = {}
    self._index_manager = index_manager
    self._zk_client = zk_client
    self._ds_access = datastore_access

    self._zk_client.DataWatch(self.indexes_node, self._update_indexes_watch)

    # Since this manager can be used synchronously, ensure that the indexes
    # are populated for this IOLoop iteration.
    try:
      encoded_indexes = self._zk_client.get(self.indexes_node)[0]
    except NoNodeError:
      encoded_indexes = '[]'

    self.indexes = [DatastoreIndex.from_dict(self.project_id, index)
                    for index in json.loads(encoded_indexes)]
Example #25
 def _start(self):
     yield self.scheduler._sync_center()
     self._scheduler_start_event = Event()
     self.coroutines = [self.scheduler.start(), self.report()]
     _global_executors.add(self)
     yield self._scheduler_start_event.wait()
     logger.debug("Started scheduling coroutines. Synchronized")
Example #26
    def __init__(self, routes, node, pipe):
        """
        Application instantiates and registers handlers for each message type,
        and routes messages to the pre-instantiated instances of each message handler

        :param routes: list of tuples in the form of (<message type str>, <MessageHandler class>)
        :param node: Node instance of the local node
        :param pipe: Instance of multiprocessing.Pipe for communicating with the parent process
        """
        # We don't really have to worry about synchronization
        # so long as we're careful about explicit context switching
        self.nodes = {node.node_id: node}

        self.local_node = node
        self.handlers = {}

        self.tcpclient = TCPClient()

        self.gossip_inbox = Queue()
        self.gossip_outbox = Queue()

        self.sequence_number = 0

        if routes:
            self.add_handlers(routes)

        self.pipe = pipe
        self.ioloop = IOLoop.current()

        self.add_node_event = Event()
Example #27
 def test_read_until_regex_max_bytes_inline(self):
     rs, ws = yield self.make_iostream_pair()
     closed = Event()
     rs.set_close_callback(closed.set)
     try:
         # Similar to the error case in the previous test, but the
         # ws writes first so rs reads are satisfied
         # inline.  For consistency with the out-of-line case, we
         # do not raise the error synchronously.
         ws.write(b"123456")
         with ExpectLog(gen_log, "Unsatisfiable read"):
             rs.read_until_regex(b"def", max_bytes=5)
             yield closed.wait()
     finally:
         ws.close()
         rs.close()
Example #28
 def test_open_coroutine(self):
     self.message_sent = Event()
     ws = yield self.ws_connect('/open_coroutine')
     yield ws.write_message('hello')
     self.message_sent.set()
     res = yield ws.read_message()
     self.assertEqual(res, 'ok')
     yield self.close(ws)
Example #29
 def __init__(self, limit=512, deserialize=True):
     self.open = 0
     self.active = 0
     self.limit = limit
     self.available = defaultdict(set)
     self.occupied = defaultdict(set)
     self.deserialize = deserialize
     self.event = Event()
Example #30
    def _start(self, timeout=3, **kwargs):
        if isinstance(self._start_arg, Scheduler):
            self.scheduler = self._start_arg
            self.center = self._start_arg.center
        if isinstance(self._start_arg, str):
            ip, port = tuple(self._start_arg.split(':'))
            self._start_arg = (ip, int(port))
        if isinstance(self._start_arg, tuple):
            r = coerce_to_rpc(self._start_arg, timeout=timeout)
            try:
                ident = yield r.identity()
            except (StreamClosedError, OSError):
                raise IOError("Could not connect to %s:%d" % self._start_arg)
            if ident['type'] == 'Center':
                self.center = r
                self.scheduler = Scheduler(self.center, loop=self.loop,
                                           **kwargs)
                self.scheduler.listen(0)
            elif ident['type'] == 'Scheduler':
                self.scheduler = r
                self.scheduler_stream = yield connect(*self._start_arg)
                yield write(self.scheduler_stream, {'op': 'register-client',
                                                    'client': self.id})
                if 'center' in ident:
                    cip, cport = ident['center']
                    self.center = rpc(ip=cip, port=cport)
                else:
                    self.center = self.scheduler
            else:
                raise ValueError("Unknown Type")

        if isinstance(self.scheduler, Scheduler):
            if self.scheduler.status != 'running':
                yield self.scheduler.sync_center()
                self.scheduler.start(0)
            self.scheduler_queue = Queue()
            self.report_queue = Queue()
            self.coroutines.append(self.scheduler.handle_queues(
                self.scheduler_queue, self.report_queue))

        start_event = Event()
        self.coroutines.append(self._handle_report(start_event))

        _global_executor[0] = self
        yield start_event.wait()
        logger.debug("Started scheduling coroutines. Synchronized")
Example #31
class Server:
    """ Dask Distributed Server

    Superclass for endpoints in a distributed cluster, such as Worker
    and Scheduler objects.

    **Handlers**

    Servers define operations with a ``handlers`` dict mapping operation names
    to functions.  The first argument of a handler function will be a ``Comm``
    for the communication established with the client.  Other arguments
    will receive inputs from the keys of the incoming message which will
    always be a dictionary.

    >>> def pingpong(comm):
    ...     return b'pong'

    >>> def add(comm, x, y):
    ...     return x + y

    >>> handlers = {'ping': pingpong, 'add': add}
    >>> server = Server(handlers)  # doctest: +SKIP
    >>> server.listen('tcp://0.0.0.0:8000')  # doctest: +SKIP

    **Message Format**

    The server expects messages to be dictionaries with a special key, `'op'`
    that corresponds to the name of the operation, and other key-value pairs as
    required by the function.

    So in the example above the following would be good messages.

    *  ``{'op': 'ping'}``
    *  ``{'op': 'add', 'x': 10, 'y': 20}``

    """

    default_ip = ""
    default_port = 0

    def __init__(
        self,
        handlers,
        blocked_handlers=None,
        stream_handlers=None,
        connection_limit=512,
        deserialize=True,
        io_loop=None,
    ):
        self.handlers = {
            "identity": self.identity,
            "connection_stream": self.handle_stream,
        }
        self.handlers.update(handlers)
        if blocked_handlers is None:
            blocked_handlers = dask.config.get(
                "distributed.%s.blocked-handlers" %
                type(self).__name__.lower(), [])
        self.blocked_handlers = blocked_handlers
        self.stream_handlers = {}
        self.stream_handlers.update(stream_handlers or {})

        self.id = type(self).__name__ + "-" + str(uuid.uuid4())
        self._address = None
        self._listen_address = None
        self._port = None
        self._comms = {}
        self.deserialize = deserialize
        self.monitor = SystemMonitor()
        self.counters = None
        self.digests = None
        self.events = None
        self.event_counts = None
        self._ongoing_coroutines = weakref.WeakSet()
        self._event_finished = Event()

        self.listeners = []
        self.io_loop = io_loop or IOLoop.current()
        self.loop = self.io_loop

        if not hasattr(self.io_loop, "profile"):
            ref = weakref.ref(self.io_loop)

            if hasattr(self.io_loop, "asyncio_loop"):

                def stop():
                    loop = ref()
                    return loop is None or loop.asyncio_loop.is_closed()

            else:

                def stop():
                    loop = ref()
                    return loop is None or loop._closing

            self.io_loop.profile = profile.watch(
                omit=("profile.py", "selectors.py"),
                interval=dask.config.get(
                    "distributed.worker.profile.interval"),
                cycle=dask.config.get("distributed.worker.profile.cycle"),
                stop=stop,
            )

        # Statistics counters for various events
        with ignoring(ImportError):
            from .counter import Digest

            self.digests = defaultdict(partial(Digest, loop=self.io_loop))

        from .counter import Counter

        self.counters = defaultdict(partial(Counter, loop=self.io_loop))
        self.events = defaultdict(lambda: deque(maxlen=10000))
        self.event_counts = defaultdict(lambda: 0)

        self.periodic_callbacks = dict()

        pc = PeriodicCallback(self.monitor.update, 500, io_loop=self.io_loop)
        self.periodic_callbacks["monitor"] = pc

        self._last_tick = time()
        pc = PeriodicCallback(
            self._measure_tick,
            parse_timedelta(dask.config.get("distributed.admin.tick.interval"),
                            default="ms") * 1000,
            io_loop=self.io_loop,
        )
        self.periodic_callbacks["tick"] = pc

        self.thread_id = 0

        def set_thread_ident():
            self.thread_id = threading.get_ident()

        self.io_loop.add_callback(set_thread_ident)

        self.__stopped = False

    async def finished(self):
        """ Wait until the server has finished """
        await self._event_finished.wait()

    def start_periodic_callbacks(self):
        """ Start Periodic Callbacks consistently

        This starts all PeriodicCallbacks stored in self.periodic_callbacks if
        they are not yet running.  It does this safely on the IOLoop.
        """
        self._last_tick = time()

        def start_pcs():
            for pc in self.periodic_callbacks.values():
                if not pc.is_running():
                    pc.start()

        self.io_loop.add_callback(start_pcs)

    def stop(self):
        if not self.__stopped:
            self.__stopped = True
            for listener in self.listeners:
                # Delay closing the server socket until the next IO loop tick.
                # Otherwise race conditions can appear if an event handler
                # for an accept() call is already scheduled by the IO loop,
                # raising EBADF.
                # The demonstrator for this is Worker.terminate(), which
                # closes the server socket in response to an incoming message.
                # See https://github.com/tornadoweb/tornado/issues/2069
                self.io_loop.add_callback(listener.stop)

    @property
    def listener(self):
        if self.listeners:
            return self.listeners[0]
        else:
            return None

    def _measure_tick(self):
        now = time()
        diff = now - self._last_tick
        self._last_tick = now
        if diff > tick_maximum_delay:
            logger.info(
                "Event loop was unresponsive in %s for %.2fs.  "
                "This is often caused by long-running GIL-holding "
                "functions or moving large chunks of data. "
                "This can cause timeouts and instability.",
                type(self).__name__,
                diff,
            )
        if self.digests is not None:
            self.digests["tick-duration"].add(diff)

    def log_event(self, name, msg):
        msg["time"] = time()
        if isinstance(name, list):
            for n in name:
                self.events[n].append(msg)
                self.event_counts[n] += 1
        else:
            self.events[name].append(msg)
            self.event_counts[name] += 1

    @property
    def address(self):
        """
        The address this Server can be contacted on.
        """
        if not self._address:
            if self.listener is None:
                raise ValueError("cannot get address of non-running Server")
            self._address = self.listener.contact_address
        return self._address

    @property
    def listen_address(self):
        """
        The address this Server is listening on.  This may be a wildcard
        address such as `tcp://0.0.0.0:1234`.
        """
        if not self._listen_address:
            if self.listener is None:
                raise ValueError(
                    "cannot get listen address of non-running Server")
            self._listen_address = self.listener.listen_address
        return self._listen_address

    @property
    def port(self):
        """
        The port number this Server is listening on.

        This will raise ValueError if the Server is listening on a
        non-IP based protocol.
        """
        if not self._port:
            _, self._port = get_address_host_port(self.address)
        return self._port

    def identity(self, comm=None):
        return {"type": type(self).__name__, "id": self.id}

    async def listen(self, port_or_addr=None, listen_args=None):
        if port_or_addr is None:
            port_or_addr = self.default_port
        if isinstance(port_or_addr, int):
            addr = unparse_host_port(self.default_ip, port_or_addr)
        elif isinstance(port_or_addr, tuple):
            addr = unparse_host_port(*port_or_addr)
        else:
            addr = port_or_addr
            assert isinstance(addr, str)
        listener = listen(
            addr,
            self.handle_comm,
            deserialize=self.deserialize,
            connection_args=listen_args,
        )
        await listener.start()
        self.listeners.append(listener)

    async def handle_comm(self, comm, shutting_down=shutting_down):
        """ Dispatch new communications to coroutine-handlers

        Handlers is a dictionary mapping operation names to functions or
        coroutines.

            {'get_data': get_data,
             'ping': pingpong}

        Coroutines should expect a single Comm object.
        """
        if self.__stopped:
            comm.abort()
            return
        address = comm.peer_address
        op = None

        logger.debug("Connection from %r to %s", address, type(self).__name__)
        self._comms[comm] = op
        try:
            while True:
                try:
                    msg = await comm.read()
                    logger.debug("Message from %r: %s", address, msg)
                except EnvironmentError as e:
                    if not shutting_down():
                        logger.debug(
                            "Lost connection to %r while reading message: %s."
                            " Last operation: %s",
                            address,
                            e,
                            op,
                        )
                    break
                except Exception as e:
                    logger.exception(e)
                    await comm.write(error_message(e, status="uncaught-error"))
                    continue
                if not isinstance(msg, dict):
                    raise TypeError(
                        "Bad message type.  Expected dict, got\n  " + str(msg))

                try:
                    op = msg.pop("op")
                except KeyError:
                    raise ValueError(
                        "Received unexpected message without 'op' key: " +
                        str(msg))
                if self.counters is not None:
                    self.counters["op"].add(op)
                self._comms[comm] = op
                serializers = msg.pop("serializers", None)
                close_desired = msg.pop("close", False)
                reply = msg.pop("reply", True)
                if op == "close":
                    if reply:
                        await comm.write("OK")
                    break

                result = None
                try:
                    if op in self.blocked_handlers:
                        _msg = (
                            "The '{op}' handler has been explicitly disallowed "
                            "in {obj}, possibly due to security concerns.")
                        exc = ValueError(
                            _msg.format(op=op, obj=type(self).__name__))
                        handler = raise_later(exc)
                    else:
                        handler = self.handlers[op]
                except KeyError:
                    logger.warning(
                        "No handler %s found in %s",
                        op,
                        type(self).__name__,
                        exc_info=True,
                    )
                else:
                    if serializers is not None and has_keyword(
                            handler, "serializers"):
                        msg["serializers"] = serializers  # add back in

                    logger.debug("Calling into handler %s", handler.__name__)
                    try:
                        result = handler(comm, **msg)
                        if isawaitable(result):
                            result = asyncio.ensure_future(result)
                            self._ongoing_coroutines.add(result)
                            result = await result
                    except (CommClosedError, CancelledError) as e:
                        if self.status == "running":
                            logger.info("Lost connection to %r: %s", address,
                                        e)
                        break
                    except Exception as e:
                        logger.exception(e)
                        result = error_message(e, status="uncaught-error")

                if reply and result != "dont-reply":
                    try:
                        await comm.write(result, serializers=serializers)
                    except (EnvironmentError, TypeError) as e:
                        logger.debug(
                            "Lost connection to %r while sending result for op %r: %s",
                            address,
                            op,
                            e,
                        )
                        break
                msg = result = None
                if close_desired:
                    await comm.close()
                if comm.closed():
                    break

        finally:
            del self._comms[comm]
            if not shutting_down() and not comm.closed():
                try:
                    comm.abort()
                except Exception as e:
                    logger.error("Failed while closing connection to %r: %s",
                                 address, e)

    async def handle_stream(self, comm, extra=None, every_cycle=[]):
        extra = extra or {}
        logger.info("Starting established connection")

        io_error = None
        closed = False
        try:
            while not closed:
                msgs = await comm.read()
                if not isinstance(msgs, (tuple, list)):
                    msgs = (msgs, )

                if not comm.closed():
                    for msg in msgs:
                        if msg == "OK":  # from close
                            break
                        op = msg.pop("op")
                        if op:
                            if op == "close-stream":
                                closed = True
                                break
                            handler = self.stream_handlers[op]
                            if is_coroutine_function(handler):
                                self.loop.add_callback(handler,
                                                       **merge(extra, msg))
                            else:
                                handler(**merge(extra, msg))
                        else:
                            logger.error("odd message %s", msg)
                    await asyncio.sleep(0)

                for func in every_cycle:
                    func()

        except (CommClosedError, EnvironmentError) as e:
            io_error = e
        except Exception as e:
            logger.exception(e)
            if LOG_PDB:
                import pdb

                pdb.set_trace()
            raise
        finally:
            await comm.close()
            assert comm.closed()

    @gen.coroutine
    def close(self):
        for pc in self.periodic_callbacks.values():
            pc.stop()
        for listener in self.listeners:
            listener.stop()
        for i in range(20):  # let comms close naturally for a second
            if not self._comms:
                break
            else:
                yield asyncio.sleep(0.05)
        yield [comm.close() for comm in self._comms]  # then forcefully close
        for cb in self._ongoing_coroutines:
            cb.cancel()
        for i in range(10):
            if all(c.cancelled() for c in self._ongoing_coroutines):
                break
            else:
                yield asyncio.sleep(0.01)

        self._event_finished.set()
Example #32
class UpdateManager:
    def __init__(self, config):
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        self.repo_debug = config.getboolean('enable_repo_debug', False)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        if self.repo_debug:
            logging.warn("UPDATE MANAGER: REPO DEBUG ENABLED")
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self),
            "moonraker": GitUpdater(self, mooncfg, MOONRAKER_PATH, env)
        }
        self.current_update = None
        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(self, {
                'repo': client_repo,
                'path': client_path
            })
        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added" %
                                   (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(self, cfg)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(self, cfg)
            else:
                raise config.error("Invalid type '%s' for section [%s]" %
                                   (client_type, section))

        # GitHub API Rate Limit Tracking
        self.gh_rate_limit = None
        self.gh_limit_remaining = None
        self.gh_limit_reset_time = None
        self.gh_init_evt = Event()
        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        AsyncHTTPClient.configure(None, defaults=dict(user_agent="Moonraker"))
        self.http_client = AsyncHTTPClient()

        self.server.register_endpoint("/machine/update/moonraker", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/klipper", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/system", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/client", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)
        self.server.register_notification("update_manager:update_response")
        self.server.register_notification("update_manager:update_refreshed")

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initalize_updaters,
                                        list(self.updaters.values()))

    async def _initalize_updaters(self, initial_updaters):
        self.is_refreshing = True
        await self._init_api_rate_limit()
        for updater in initial_updaters:
            if isinstance(updater, PackageUpdater):
                ret = updater.refresh(False)
            else:
                ret = updater.refresh()
            if asyncio.iscoroutine(ret):
                await ret
        self.is_refreshing = False

    async def _set_klipper_repo(self):
        kinfo = self.server.get_klippy_info()
        if not kinfo:
            logging.info("No valid klippy info received")
            return
        kpath = kinfo['klipper_path']
        env = kinfo['python_path']
        kupdater = self.updaters.get('klipper', None)
        if kupdater is not None and kupdater.repo_path == kpath and \
                kupdater.env == env:
            # Current Klipper Updater is valid
            return
        kcfg = self.config[f"update_manager static {self.distro} klipper"]
        self.updaters['klipper'] = GitUpdater(self, kcfg, kpath, env)
        await self.updaters['klipper'].refresh()

    async def _check_klippy_printing(self):
        klippy_apis = self.server.lookup_plugin('klippy_apis')
        result = await klippy_apis.query_objects({'print_stats': None},
                                                 default={})
        pstate = result.get('print_stats', {}).get('state', "")
        return pstate.lower() == "printing"

    async def _handle_auto_refresh(self):
        if await self._check_klippy_printing():
            # Don't Refresh during a print
            logging.info("Klippy is printing, auto refresh aborted")
            return
        cur_time = time.time()
        cur_hour = time.localtime(cur_time).tm_hour
        time_diff = cur_time - self.last_auto_update_time
        # Update packages if it has been more than 12 hours
        # and the local time is between 12AM and 5AM
        if time_diff < MIN_REFRESH_TIME or cur_hour >= MAX_PKG_UPDATE_HOUR:
            # Not within the update time window
            return
        self.last_auto_update_time = cur_time
        vinfo = {}
        need_refresh_all = not self.is_refreshing
        async with self.cmd_request_lock:
            self.is_refreshing = True
            try:
                for name, updater in list(self.updaters.items()):
                    if need_refresh_all:
                        ret = updater.refresh()
                        if asyncio.iscoroutine(ret):
                            await ret
                    if hasattr(updater, "get_update_status"):
                        vinfo[name] = updater.get_update_status()
            except Exception:
                logging.exception("Unable to Refresh Status")
                return
            finally:
                self.is_refreshing = False
        uinfo = {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }
        self.server.send_event("update_manager:update_refreshed", uinfo)

    async def _handle_update_request(self, web_request):
        if await self._check_klippy_printing():
            raise self.server.error("Update Refused: Klippy is printing")
        app = web_request.get_endpoint().split("/")[-1]
        if app == "client":
            app = web_request.get('name')
        inc_deps = web_request.get_boolean('include_deps', False)
        if self.current_update is not None and \
                self.current_update[0] == app:
            return f"Object {app} is currently being updated"
        updater = self.updaters.get(app, None)
        if updater is None:
            raise self.server.error(f"Updater {app} not available")
        async with self.cmd_request_lock:
            self.current_update = (app, id(web_request))
            try:
                await updater.update(inc_deps)
            except Exception as e:
                self.notify_update_response(f"Error updating {app}")
                self.notify_update_response(str(e), is_complete=True)
                raise
            finally:
                self.current_update = None
        return "ok"

    async def _handle_status_request(self, web_request):
        check_refresh = web_request.get_boolean('refresh', False)
        # Don't refresh if a print is currently in progress or
        # if an update is in progress.  Just return the current
        # state
        if self.current_update is not None or \
                await self._check_klippy_printing():
            check_refresh = False
        need_refresh = False
        if check_refresh:
            # If there is an outstanding request processing a
            # refresh, we don't need to do it again.
            need_refresh = not self.is_refreshing
            await self.cmd_request_lock.acquire()
            self.is_refreshing = True
        vinfo = {}
        try:
            for name, updater in list(self.updaters.items()):
                await updater.check_initialized(120.)
                if need_refresh:
                    ret = updater.refresh()
                    if asyncio.iscoroutine(ret):
                        await ret
                if hasattr(updater, "get_update_status"):
                    vinfo[name] = updater.get_update_status()
        except Exception:
            raise
        finally:
            if check_refresh:
                self.is_refreshing = False
                self.cmd_request_lock.release()
        return {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }

    async def execute_cmd(self, cmd, timeout=10., notify=False, retries=1):
        shell_command = self.server.lookup_plugin('shell_command')
        cb = self.notify_update_response if notify else None
        scmd = shell_command.build_shell_command(cmd, callback=cb)
        while retries:
            if await scmd.run(timeout=timeout, verbose=notify):
                break
            retries -= 1
        if not retries:
            raise self.server.error("Shell Command Error")

    async def execute_cmd_with_response(self, cmd, timeout=10.):
        shell_command = self.server.lookup_plugin('shell_command')
        scmd = shell_command.build_shell_command(cmd, None)
        result = await scmd.run_with_response(timeout, retries=5)
        if result is None:
            raise self.server.error(f"Error Running Command: {cmd}")
        return result

    async def _init_api_rate_limit(self):
        url = "https://api.github.com/rate_limit"
        while 1:
            try:
                resp = await self.github_api_request(url, is_init=True)
                core = resp['resources']['core']
                self.gh_rate_limit = core['limit']
                self.gh_limit_remaining = core['remaining']
                self.gh_limit_reset_time = core['reset']
            except Exception:
                logging.exception("Error Initializing GitHub API Rate Limit")
                await tornado.gen.sleep(30.)
            else:
                reset_time = time.ctime(self.gh_limit_reset_time)
                logging.info(
                    "GitHub API Rate Limit Initialized\n"
                    f"Rate Limit: {self.gh_rate_limit}\n"
                    f"Rate Limit Remaining: {self.gh_limit_remaining}\n"
                    f"Rate Limit Reset Time: {reset_time}, "
                    f"Seconds Since Epoch: {self.gh_limit_reset_time}")
                break
        self.gh_init_evt.set()

    async def github_api_request(self, url, etag=None, is_init=False):
        if not is_init:
            timeout = time.time() + 30.
            try:
                await self.gh_init_evt.wait(timeout)
            except Exception:
                raise self.server.error("Timeout while waiting for GitHub "
                                        "API Rate Limit initialization")
        if self.gh_limit_remaining == 0:
            curtime = time.time()
            if curtime < self.gh_limit_reset_time:
                raise self.server.error(
                    f"GitHub Rate Limit Reached\nRequest: {url}\n"
                    f"Limit Reset Time: {time.ctime(self.gh_limit_reset_time)}")
        headers = {"Accept": "application/vnd.github.v3+json"}
        if etag is not None:
            headers['If-None-Match'] = etag
        retries = 5
        while retries:
            try:
                timeout = time.time() + 10.
                fut = self.http_client.fetch(url,
                                             headers=headers,
                                             connect_timeout=5.,
                                             request_timeout=5.,
                                             raise_error=False)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                msg = f"Error Processing GitHub API request: {url}"
                if not retries:
                    raise self.server.error(msg)
                logging.exception(msg)
                await tornado.gen.sleep(1.)
                continue
            etag = resp.headers.get('etag', None)
            if etag is not None:
                if etag[:2] == "W/":
                    etag = etag[2:]
            logging.info("GitHub API Request Processed\n"
                         f"URL: {url}\n"
                         f"Response Code: {resp.code}\n"
                         f"Response Reason: {resp.reason}\n"
                         f"ETag: {etag}")
            if resp.code == 403:
                raise self.server.error(
                    f"Forbidden GitHub Request: {resp.reason}")
            elif resp.code == 304:
                logging.info(f"Github Request not Modified: {url}")
                return None
            if resp.code != 200:
                retries -= 1
                if not retries:
                    raise self.server.error(
                        f"Github Request failed: {resp.code} {resp.reason}")
                logging.info(
                    f"Github request error, {retries} retries remaining")
                await tornado.gen.sleep(1.)
                continue
            # Update rate limit on return success
            if 'X-Ratelimit-Limit' in resp.headers and not is_init:
                self.gh_rate_limit = int(resp.headers['X-Ratelimit-Limit'])
                self.gh_limit_remaining = int(
                    resp.headers['X-Ratelimit-Remaining'])
                self.gh_limit_reset_time = float(
                    resp.headers['X-Ratelimit-Reset'])
            decoded = json.loads(resp.body)
            decoded['etag'] = etag
            return decoded

    async def http_download_request(self, url):
        retries = 5
        while retries:
            try:
                timeout = time.time() + 130.
                fut = self.http_client.fetch(
                    url,
                    headers={"Accept": "application/zip"},
                    connect_timeout=5.,
                    request_timeout=120.)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                logging.exception("Error Processing Download")
                if not retries:
                    raise
                await tornado.gen.sleep(1.)
                continue
            return resp.body

    def notify_update_response(self, resp, is_complete=False):
        resp = resp.strip()
        if isinstance(resp, bytes):
            resp = resp.decode()
        notification = {
            'message': resp,
            'application': None,
            'proc_id': None,
            'complete': is_complete
        }
        if self.current_update is not None:
            notification['application'] = self.current_update[0]
            notification['proc_id'] = self.current_update[1]
        self.server.send_event("update_manager:update_response", notification)

    def close(self):
        self.http_client.close()
        if self.refresh_cb is not None:
            self.refresh_cb.stop()
Example #33
class PackageUpdater:
    def __init__(self, umgr):
        self.server = umgr.server
        self.execute_cmd = umgr.execute_cmd
        self.execute_cmd_with_response = umgr.execute_cmd_with_response
        self.notify_update_response = umgr.notify_update_response
        self.available_packages = []
        self.init_evt = Event()
        self.refresh_condition = None

    async def refresh(self, fetch_packages=True):
        # TODO: Use python-apt python lib rather than command line for updates
        if self.refresh_condition is None:
            self.refresh_condition = Condition()
        else:
            await self.refresh_condition.wait()
            return
        try:
            if fetch_packages:
                await self.execute_cmd(f"{APT_CMD} update",
                                       timeout=300.,
                                       retries=3)
            res = await self.execute_cmd_with_response("apt list --upgradable",
                                                       timeout=60.)
            pkg_list = [p.strip() for p in res.split("\n") if p.strip()]
            if pkg_list:
                pkg_list = pkg_list[2:]
                self.available_packages = [
                    p.split("/", maxsplit=1)[0] for p in pkg_list
                ]
            pkg_list = "\n".join(self.available_packages)
            logging.info(
                f"Detected {len(self.available_packages)} package updates:"
                f"\n{pkg_list}")
        except Exception:
            logging.exception("Error Refreshing System Packages")
        self.init_evt.set()
        self.refresh_condition.notify_all()
        self.refresh_condition = None

    async def check_initialized(self, timeout=None):
        if self.init_evt.is_set():
            return
        if timeout is not None:
            timeout = IOLoop.current().time() + timeout
        await self.init_evt.wait(timeout)

    async def update(self, *args):
        await self.check_initialized(20.)
        if self.refresh_condition is not None:
            await self.refresh_condition.wait()
        self.notify_update_response("Updating packages...")
        try:
            await self.execute_cmd(f"{APT_CMD} update",
                                   timeout=300.,
                                   notify=True)
            await self.execute_cmd(f"{APT_CMD} upgrade --yes",
                                   timeout=3600.,
                                   notify=True)
        except Exception:
            raise self.server.error("Error updating system packages")
        self.available_packages = []
        self.notify_update_response("Package update finished...",
                                    is_complete=True)

    def get_update_status(self):
        return {
            'package_count': len(self.available_packages),
            'package_list': self.available_packages
        }
Example #34
class ClientRequestHandler(RequestHandler):
    """A RequestHandler for client requests to this MHS."""
    def initialize(self, sender: Sender, callbacks: Dict[str, Callable[[str],
                                                                       None]],
                   async_timeout: int):
        """Initialise this request handler with the provided configuration values.

        :param sender: The sender to use to send messages.
        :param callbacks: The dictionary of callbacks to use when awaiting an asynchronous response.
        :param async_timeout: The amount of time (in seconds) to wait for an asynchronous response.
        """
        self.sender = sender
        self.callbacks = callbacks
        self.async_response_received = Event()
        self.async_timeout = async_timeout

    async def post(self):
        logging.debug("Client POST received: %s", self.request)

        interaction_name = self.request.uri[1:]

        async_message_id = None
        try:
            interaction_is_async, async_message_id, message = self.sender.prepare_message(
                interaction_name, self.request.body.decode())

            if interaction_is_async:
                self.callbacks[async_message_id] = self._write_async_response
                logging.debug(
                    "Added callback for asynchronous message with ID '%s'",
                    async_message_id)

            immediate_response = self.sender.send_message(
                interaction_name, message)

            if interaction_is_async:
                await self._pause_request(async_message_id)
            else:
                # No async response expected. Just return the response to our initial request.
                self._write_response(immediate_response)

        except UnknownInteractionError:
            raise HTTPError(404, "Unknown interaction ID: %s",
                            interaction_name)
        finally:
            if async_message_id in self.callbacks:
                del self.callbacks[async_message_id]

    async def _pause_request(self, async_message_id):
        """Pause the incoming request until an asynchronous response is received (or a timeout is hit).

        :param async_message_id: The ID of the request that expects an asynchronous response.
        :raises: An HTTPError if we timed out waiting for an asynchronous response.
        """
        try:
            logging.debug(
                "Waiting for asynchronous response to message with ID '%s'",
                async_message_id)
            await self.async_response_received.wait(
                timedelta(seconds=self.async_timeout))
        except TimeoutError:
            raise HTTPError(
                log_message=
                f"Timed out waiting for a response to message {async_message_id}"
            )

    def _write_response(self, message: str) -> None:
        """Write the given message to the response.

        :param message: The message to write to the response.
        """
        self.set_header("Content-Type", "text/xml")
        self.write(message)

    def _write_async_response(self, message: str) -> None:
        """Write the given message to the response and notify anyone waiting that this has been done.

        :param message: The message to write to the response.
        """
        logging.debug("Received asynchronous response containing message '%s'",
                      message)
        self._write_response(message)
        self.async_response_received.set()
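# The handler above parks an open HTTP request on a tornado Event and resumes
# it when the registered callback fires. A self-contained sketch of that
# wait/notify flow, assuming only Tornado; parked_request and deliver are
# illustrative helpers, not part of the MHS code.
from datetime import timedelta

from tornado.ioloop import IOLoop
from tornado.locks import Event
from tornado.util import TimeoutError


async def parked_request(callbacks, message_id, timeout_s=5):
    """Wait until deliver() fires the callback for message_id, or time out."""
    received = Event()
    responses = {}

    def on_response(body):
        responses["body"] = body
        received.set()

    callbacks[message_id] = on_response
    try:
        await received.wait(timedelta(seconds=timeout_s))
        return responses["body"]
    except TimeoutError:
        return None
    finally:
        callbacks.pop(message_id, None)


def deliver(callbacks, message_id, body):
    """Simulate the asynchronous response arriving from the remote system."""
    callback = callbacks.get(message_id)
    if callback is not None:
        callback(body)


async def _demo():
    callbacks = {}
    # Deliver the "asynchronous response" shortly after the request is parked.
    IOLoop.current().call_later(0.1, deliver, callbacks, "msg-1", "<ack/>")
    print(await parked_request(callbacks, "msg-1"))  # -> <ack/>


IOLoop.current().run_sync(_demo)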
Example #35
class Queue(object):
    """Coordinate producer and consumer coroutines.

    If maxsize is 0 (the default) the queue size is unbounded.
    """
    def __init__(self, maxsize=0):
        if maxsize is None:
            raise TypeError("maxsize can't be None")

        if maxsize < 0:
            raise ValueError("maxsize can't be negative")

        self._maxsize = maxsize
        self._init()
        self._getters = collections.deque([])  # Futures.
        self._putters = collections.deque([])  # Pairs of (item, Future).
        self._unfinished_tasks = 0
        self._finished = Event()
        self._finished.set()

    @property
    def maxsize(self):
        """Number of items allowed in the queue."""
        return self._maxsize

    def qsize(self):
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self):
        return not self._queue

    def full(self):
        if self.maxsize == 0:
            return False
        else:
            return self.qsize() >= self.maxsize

    def put(self, item, timeout=None):
        """Put an item into the queue, perhaps waiting until there is room.

        Returns a Future, which raises `tornado.gen.TimeoutError` after a
        timeout.
        """
        try:
            self.put_nowait(item)
        except QueueFull:
            future = Future()
            self._putters.append((item, future))
            _set_timeout(future, timeout)
            return future
        else:
            return gen._null_future

    def put_nowait(self, item):
        """Put an item into the queue without blocking.

        If no free slot is immediately available, raise `QueueFull`.
        """
        self._consume_expired()
        if self._getters:
            assert self.empty(), "queue non-empty, why are getters waiting?"
            getter = self._getters.popleft()
            self.__put_internal(item)
            getter.set_result(self._get())
        elif self.full():
            raise QueueFull
        else:
            self.__put_internal(item)

    def get(self, timeout=None):
        """Remove and return an item from the queue.

        Returns a Future which resolves once an item is available, or raises
        `tornado.gen.TimeoutError` after a timeout.
        """
        future = Future()
        try:
            future.set_result(self.get_nowait())
        except QueueEmpty:
            self._getters.append(future)
            _set_timeout(future, timeout)
        return future

    def get_nowait(self):
        """Remove and return an item from the queue without blocking.

        Return an item if one is immediately available, else raise
        `QueueEmpty`.
        """
        self._consume_expired()
        if self._putters:
            assert self.full(), "queue not full, why are putters waiting?"
            item, putter = self._putters.popleft()
            self.__put_internal(item)
            putter.set_result(None)
            return self._get()
        elif self.qsize():
            return self._get()
        else:
            raise QueueEmpty

    def task_done(self):
        """Indicate that a formerly enqueued task is complete.

        Used by queue consumers. For each `.get` used to fetch a task, a
        subsequent call to `.task_done` tells the queue that the processing
        on the task is complete.

        If a `.join` is blocking, it resumes when all items have been
        processed; that is, when every `.put` is matched by a `.task_done`.

        Raises `ValueError` if called more times than `.put`.
        """
        if self._unfinished_tasks <= 0:
            raise ValueError('task_done() called too many times')
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    def join(self, timeout=None):
        """Block until all items in the queue are processed.

        Returns a Future, which raises `tornado.gen.TimeoutError` after a
        timeout.
        """
        return self._finished.wait(timeout)

    # These three are overridable in subclasses.
    def _init(self):
        self._queue = collections.deque()

    def _get(self):
        return self._queue.popleft()

    def _put(self, item):
        self._queue.append(item)

    # End of the overridable methods.

    def __put_internal(self, item):
        self._unfinished_tasks += 1
        self._finished.clear()
        self._put(item)

    def _consume_expired(self):
        # Remove timed-out waiters.
        while self._putters and self._putters[0][1].done():
            self._putters.popleft()

        while self._getters and self._getters[0].done():
            self._getters.popleft()

    def __repr__(self):
        return '<%s at %s %s>' % (type(self).__name__, hex(
            id(self)), self._format())

    def __str__(self):
        return '<%s %s>' % (type(self).__name__, self._format())

    def _format(self):
        result = 'maxsize=%r' % (self.maxsize, )
        if getattr(self, '_queue', None):
            result += ' queue=%r' % self._queue
        if self._getters:
            result += ' getters[%s]' % len(self._getters)
        if self._putters:
            result += ' putters[%s]' % len(self._putters)
        if self._unfinished_tasks:
            result += ' tasks=%s' % self._unfinished_tasks
        return result
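# The non-blocking half of the API above can be exercised directly. A short
# sketch, assuming the class behaves like tornado.queues.Queue (from which
# QueueFull and QueueEmpty are imported here).
from tornado.ioloop import IOLoop
from tornado.queues import Queue, QueueEmpty, QueueFull


async def _demo():
    q = Queue(maxsize=1)
    q.put_nowait("a")          # fills the single slot
    try:
        q.put_nowait("b")      # no free slot -> QueueFull
    except QueueFull:
        print("queue is full")
    print(q.get_nowait())      # -> a
    try:
        q.get_nowait()         # nothing buffered -> QueueEmpty
    except QueueEmpty:
        print("queue is empty")


IOLoop.current().run_sync(_demo)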
Example #36
class WorkerProcess(object):
    def __init__(
        self,
        worker_args,
        worker_kwargs,
        worker_start_args,
        silence_logs,
        on_exit,
        worker,
        env,
    ):
        self.status = "init"
        self.silence_logs = silence_logs
        self.worker_args = worker_args
        self.worker_kwargs = worker_kwargs
        self.worker_start_args = worker_start_args
        self.on_exit = on_exit
        self.process = None
        self.Worker = worker
        self.env = env

        # Initialized when worker is ready
        self.worker_dir = None
        self.worker_address = None

    @gen.coroutine
    def start(self):
        """
        Ensure the worker process is started.
        """
        enable_proctitle_on_children()
        if self.status == "running":
            raise gen.Return(self.status)
        if self.status == "starting":
            yield self.running.wait()
            raise gen.Return(self.status)

        self.init_result_q = init_q = mp_context.Queue()
        self.child_stop_q = mp_context.Queue()
        uid = uuid.uuid4().hex

        self.process = AsyncProcess(
            target=self._run,
            kwargs=dict(
                worker_args=self.worker_args,
                worker_kwargs=self.worker_kwargs,
                worker_start_args=self.worker_start_args,
                silence_logs=self.silence_logs,
                init_result_q=self.init_result_q,
                child_stop_q=self.child_stop_q,
                uid=uid,
                Worker=self.Worker,
                env=self.env,
            ),
        )
        self.process.daemon = True
        self.process.set_exit_callback(self._on_exit)
        self.running = Event()
        self.stopped = Event()
        self.status = "starting"
        yield self.process.start()
        msg = yield self._wait_until_connected(uid)
        if not msg:
            raise gen.Return(self.status)
        self.worker_address = msg["address"]
        self.worker_dir = msg["dir"]
        assert self.worker_address
        self.status = "running"
        self.running.set()

        init_q.close()

        raise gen.Return(self.status)

    def _on_exit(self, proc):
        if proc is not self.process:
            # Ignore exit of old process instance
            return
        self.mark_stopped()

    def _death_message(self, pid, exitcode):
        assert exitcode is not None
        if exitcode == 255:
            return "Worker process %d was killed by unknown signal" % (pid, )
        elif exitcode >= 0:
            return "Worker process %d exited with status %d" % (pid, exitcode)
        else:
            return "Worker process %d was killed by signal %d" % (pid,
                                                                  -exitcode)

    def is_alive(self):
        return self.process is not None and self.process.is_alive()

    @property
    def pid(self):
        return self.process.pid if self.process and self.process.is_alive(
        ) else None

    def mark_stopped(self):
        if self.status != "stopped":
            r = self.process.exitcode
            assert r is not None
            if r != 0:
                msg = self._death_message(self.process.pid, r)
                logger.warning(msg)
            self.status = "stopped"
            self.stopped.set()
            # Release resources
            self.process.close()
            self.init_result_q = None
            self.child_stop_q = None
            self.process = None
            # Best effort to clean up worker directory
            if self.worker_dir and os.path.exists(self.worker_dir):
                shutil.rmtree(self.worker_dir, ignore_errors=True)
            self.worker_dir = None
            # User hook
            if self.on_exit is not None:
                self.on_exit(r)

    @gen.coroutine
    def kill(self, timeout=2, executor_wait=True):
        """
        Ensure the worker process is stopped, waiting at most
        *timeout* seconds before terminating it abruptly.
        """
        loop = IOLoop.current()
        deadline = loop.time() + timeout

        if self.status == "stopped":
            return
        if self.status == "stopping":
            yield self.stopped.wait()
            return
        assert self.status in ("starting", "running")
        self.status = "stopping"

        process = self.process
        self.child_stop_q.put({
            "op": "stop",
            "timeout": max(0, deadline - loop.time()) * 0.8,
            "executor_wait": executor_wait,
        })
        self.child_stop_q.close()

        while process.is_alive() and loop.time() < deadline:
            yield gen.sleep(0.05)

        if process.is_alive():
            logger.warning(
                "Worker process still alive after %d seconds, killing",
                timeout)
            try:
                yield process.terminate()
            except Exception as e:
                logger.error("Failed to kill worker process: %s", e)

    @gen.coroutine
    def _wait_until_connected(self, uid):
        delay = 0.05
        while True:
            if self.status != "starting":
                return
            try:
                msg = self.init_result_q.get_nowait()
            except Empty:
                yield gen.sleep(delay)
                continue

            if msg["uid"] != uid:  # ensure that we didn't cross queues
                continue

            if "exception" in msg:
                logger.error("Failed while trying to start worker process: %s",
                             msg["exception"])
                yield self.process.join()
                raise msg["exception"]
            else:
                raise gen.Return(msg)

    @classmethod
    def _run(
        cls,
        worker_args,
        worker_kwargs,
        worker_start_args,
        silence_logs,
        init_result_q,
        child_stop_q,
        uid,
        env,
        Worker,
    ):  # pragma: no cover
        os.environ.update(env)
        try:
            from dask.multiprocessing import initialize_worker_process
        except ImportError:  # old Dask version
            pass
        else:
            initialize_worker_process()

        if silence_logs:
            logger.setLevel(silence_logs)

        IOLoop.clear_instance()
        loop = IOLoop()
        loop.make_current()
        worker = Worker(*worker_args, **worker_kwargs)

        @gen.coroutine
        def do_stop(timeout=5, executor_wait=True):
            try:
                yield worker.close(
                    report=False,
                    nanny=False,
                    executor_wait=executor_wait,
                    timeout=timeout,
                )
            finally:
                loop.stop()

        def watch_stop_q():
            """
            Wait for an incoming stop message and then stop the
            worker cleanly.
            """
            while True:
                try:
                    msg = child_stop_q.get(timeout=1000)
                except Empty:
                    pass
                else:
                    child_stop_q.close()
                    assert msg.pop("op") == "stop"
                    loop.add_callback(do_stop, **msg)
                    break

        t = threading.Thread(target=watch_stop_q,
                             name="Nanny stop queue watch")
        t.daemon = True
        t.start()

        @gen.coroutine
        def run():
            """
            Try to start worker and inform parent of outcome.
            """
            try:
                yield worker._start(*worker_start_args)
            except Exception as e:
                logger.exception("Failed to start worker")
                init_result_q.put({"uid": uid, "exception": e})
                init_result_q.close()
            else:
                assert worker.address
                init_result_q.put({
                    "address": worker.address,
                    "dir": worker.local_dir,
                    "uid": uid
                })
                init_result_q.close()
                yield worker.wait_until_closed()
                logger.info("Worker closed")

        try:
            loop.run_sync(run)
        except TimeoutError:
            # Loop was stopped before wait_until_closed() returned, ignore
            pass
        except KeyboardInterrupt:
            pass
Example #37
class Queue:
    """ Distributed Queue

    This allows multiple clients to share futures or small bits of data between
    each other with a multi-producer/multi-consumer queue.  All metadata is
    sequentialized through the scheduler.

    Elements of the Queue must be either Futures or msgpack-encodable data
    (ints, strings, lists, dicts).  All data is sent through the scheduler so
    it is wise not to send large objects.  To share large objects scatter the
    data and share the future instead.

    .. warning::

       This object is experimental and has known issues in Python 2

    Examples
    --------
    >>> from dask.distributed import Client, Queue  # doctest: +SKIP
    >>> client = Client()  # doctest: +SKIP
    >>> queue = Queue('x')  # doctest: +SKIP
    >>> future = client.submit(f, x)  # doctest: +SKIP
    >>> queue.put(future)  # doctest: +SKIP

    See Also
    --------
    Variable: shared variable between clients
    """

    def __init__(self, name=None, client=None, maxsize=0):
        self.client = client or _get_global_client()
        self.name = name or "queue-" + uuid.uuid4().hex
        self._event_started = Event()
        if self.client.asynchronous or getattr(
            thread_state, "on_event_loop_thread", False
        ):

            async def _create_queue():
                await self.client.scheduler.queue_create(
                    name=self.name, maxsize=maxsize
                )
                self._event_started.set()

            self.client.loop.add_callback(_create_queue)
        else:
            sync(
                self.client.loop,
                self.client.scheduler.queue_create,
                name=self.name,
                maxsize=maxsize,
            )
            self._event_started.set()

    def __await__(self):
        async def _():
            await self._event_started.wait()
            return self

        return _().__await__()

    async def _put(self, value, timeout=None):
        if isinstance(value, Future):
            await self.client.scheduler.queue_put(
                key=tokey(value.key), timeout=timeout, name=self.name
            )
        else:
            await self.client.scheduler.queue_put(
                data=value, timeout=timeout, name=self.name
            )

    def put(self, value, timeout=None, **kwargs):
        """ Put data into the queue """
        return self.client.sync(self._put, value, timeout=timeout, **kwargs)

    def get(self, timeout=None, batch=False, **kwargs):
        """ Get data from the queue

        Parameters
        ----------
        timeout: Number (optional)
            Time in seconds to wait before timing out
        batch: boolean, int (optional)
            If True then return all elements currently waiting in the queue.
            If an integer then return that many elements from the queue.
            If False (default) then return one item at a time.
        """
        return self.client.sync(self._get, timeout=timeout, batch=batch, **kwargs)

    def qsize(self, **kwargs):
        """ Current number of elements in the queue """
        return self.client.sync(self._qsize, **kwargs)

    async def _get(self, timeout=None, batch=False):
        try:
            resp = await self.client.scheduler.queue_get(
                timeout=timeout, name=self.name, batch=batch
            )
        except gen.TimeoutError:
            raise TimeoutError("Timed out waiting for Queue")

        def process(d):
            if d["type"] == "Future":
                value = Future(d["value"], self.client, inform=True, state=d["state"])
                if d["state"] == "erred":
                    value._state.set_error(d["exception"], d["traceback"])
                self.client._send_to_scheduler(
                    {"op": "queue-future-release", "name": self.name, "key": d["value"]}
                )
            else:
                value = d["value"]

            return value

        if batch is False:
            result = process(resp)
        else:
            result = list(map(process, resp))

        return result

    async def _qsize(self):
        result = await self.client.scheduler.queue_qsize(name=self.name)
        return result

    def close(self):
        if self.client.status == "running":  # TODO: can leave zombie futures
            self.client._send_to_scheduler({"op": "queue_release", "name": self.name})

    def __getstate__(self):
        return (self.name, self.client.scheduler.address)

    def __setstate__(self, state):
        name, address = state
        try:
            client = get_client(address)
            assert client.scheduler.address == address
        except (AttributeError, AssertionError):
            client = Client(address, set_as_default=False)
        self.__init__(name=name, client=client)
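# A minimal sketch of the put/get round trip described in the docstring,
# assuming dask.distributed is installed; the in-process client and the queue
# name are illustrative.
from dask.distributed import Client, Queue

client = Client(processes=False)     # small in-process cluster for the sketch
queue = Queue("results", client=client)

future = client.submit(lambda x: x + 1, 10)
queue.put(future)                    # futures are shared via the scheduler
print(queue.get().result())          # -> 11

client.close()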
Example #38
class MockFitsWriterClient(object):
    """
    Wrapper class for a KATCP client to an EddFitsWriterServer
    """
    def __init__(self, address, record_dest):
        """
        @brief      Construct new instance
                    If record_dest is not empty, create a folder named record_dest and record the received packages there.
        """
        self._address = address
        self.__record_dest = record_dest
        if record_dest:
            if not os.path.isdir(record_dest):
                os.makedirs(record_dest)
        self._ioloop = IOLoop.current()
        self._stop_event = Event()
        self._is_stopped = Condition()
        self._socket = None
        self.__last_package = 0

    def reset_connection(self):
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.setblocking(False)
        try:
            self._socket.connect(self._address)
        except socket.error as error:
            if error.args[0] == errno.EINPROGRESS:
                pass
            else:
                raise error

    @coroutine
    def recv_nbytes(self, nbytes):
        received_bytes = 0
        data = b''
        while received_bytes < nbytes:
            if self._stop_event.is_set():
                raise StopEvent
            try:
                log.debug("Requesting {} bytes".format(nbytes -
                                                       received_bytes))
                current_data = self._socket.recv(nbytes - received_bytes)
                received_bytes += len(current_data)
                data += current_data
                log.debug("Received {} bytes ({} of {} bytes)".format(
                    len(current_data), received_bytes, nbytes))
            except socket.error as error:
                error_id = error.args[0]
                if error_id == errno.EAGAIN or error_id == errno.EWOULDBLOCK:
                    yield sleep(0.1)
                else:
                    log.exception("Unexpected error on socket recv: {}".format(
                        str(error)))
                    raise error
        raise Return(data)

    @coroutine
    def recv_loop(self):
        while not self._stop_event.is_set():
            try:
                header, sections = yield self.recv_packet()
            except StopEvent:
                log.debug("Notifying that recv calls have stopped")
                self._is_stopped.notify()
            except Exception as E:
                log.exception("Failure while receiving packet: {}".format(E))

    def start(self):
        self._stop_event.clear()
        self.reset_connection()
        self._ioloop.add_callback(self.recv_loop)

    @coroutine
    def stop(self, timeout=2):
        self._stop_event.set()
        try:
            success = yield self._is_stopped.wait(timeout=self._ioloop.time() +
                                                  timeout)
            if not success:
                raise TimeoutError
        except TimeoutError:
            log.error(("Could not stop the client within "
                       "the {} second limit").format(timeout))
        except Exception:
            log.exception("Fucup")

    @coroutine
    def recv_packet(self):
        log.debug("Receiving packet header")
        raw_header = yield self.recv_nbytes(C.sizeof(FWHeader))
        log.debug("Converting packet header")
        header = FWHeader.from_buffer_copy(raw_header)
        log.info("Received header: {}".format(header))
        if header.timestamp < self.__last_package:
            log.error("Timestamps out of order!")
        else:
            self.__last_package = header.timestamp

        if self.__record_dest:
            filename = os.path.join(self.__record_dest,
                                    "FWP_{}.dat".format(header.timestamp))
            while os.path.isfile(filename):
                log.warning("Filename {} already exists. Appending "
                            "suffix '_'".format(filename))
                filename += '_'
            log.info('Recording to file {}'.format(filename))
            ofile = open(filename, 'wb')
            ofile.write(raw_header)

        fw_data_type = header.channel_data_type.strip().upper()
        c_data_type, np_data_type = TYPE_MAP[fw_data_type]
        sections = []
        for section in range(header.nsections):
            log.debug("Receiving section {} of {}".format(
                section + 1, header.nsections))
            raw_section_header = yield self.recv_nbytes(
                C.sizeof(FWSectionHeader))
            if self.__record_dest:
                ofile.write(raw_section_header)

            section_header = FWSectionHeader.from_buffer_copy(
                raw_section_header)
            log.info("Section {} header: {}".format(section, section_header))
            log.debug("Receiving section data")
            raw_bytes = yield self.recv_nbytes(
                C.sizeof(c_data_type) * section_header.nchannels)
            if self.__record_dest:
                ofile.write(raw_bytes)
            data = np.frombuffer(raw_bytes, dtype=np_data_type)
            log.info("Section {} data: {}".format(section, data[:10]))
            sections.append((section_header, data))

        if self.__record_dest:
            ofile.close()
        raise Return((header, sections))
Example #39
class ConnectionPool(object):
    """ A maximum sized pool of Comm objects.

    This provides a connect method that mirrors the normal distributed.connect
    method, but provides connection sharing and tracks connection limits.

    This object provides an ``rpc`` like interface::

        >>> rpc = ConnectionPool(limit=512)
        >>> scheduler = rpc('127.0.0.1:8786')
        >>> workers = [rpc(address) for address ...]

        >>> info = yield scheduler.identity()

    It creates enough comms to satisfy concurrent connections to any
    particular address::

        >>> a, b = yield [scheduler.who_has(), scheduler.has_what()]

    It reuses existing comms so that we don't have to continuously reconnect.

    It also maintains a comm limit to avoid "too many open file handles"
    issues.  Whenever this maximum is reached we clear out all idling comms.
    If that doesn't do the trick then we wait until one of the occupied comms
    closes.

    Parameters
    ----------
    limit: int
        The number of open comms to maintain at once
    deserialize: bool
        Whether or not to deserialize data by default or pass it through
    """
    def __init__(self, limit=512, deserialize=True, connection_args=None):
        self.open = 0  # Total number of open comms
        self.active = 0  # Number of comms currently in use
        self.limit = limit  # Max number of open comms
        # Invariant: len(available) == open - active
        self.available = defaultdict(set)
        # Invariant: len(occupied) == active
        self.occupied = defaultdict(set)
        self.deserialize = deserialize
        self.connection_args = connection_args
        self.event = Event()

    def __repr__(self):
        return "<ConnectionPool: open=%d, active=%d>" % (self.open,
                                                         self.active)

    def __call__(self, addr=None, ip=None, port=None):
        """ Cached rpc objects """
        addr = addr_from_args(addr=addr, ip=ip, port=port)
        return PooledRPCCall(addr, self)

    @gen.coroutine
    def connect(self, addr, timeout=None):
        """
        Get a Comm to the given address.  For internal use.
        """
        available = self.available[addr]
        occupied = self.occupied[addr]
        if available:
            comm = available.pop()
            if not comm.closed():
                self.active += 1
                occupied.add(comm)
                raise gen.Return(comm)
            else:
                self.open -= 1

        while self.open >= self.limit:
            self.event.clear()
            self.collect()
            yield self.event.wait()

        self.open += 1
        try:
            comm = yield connect(addr,
                                 timeout=timeout,
                                 deserialize=self.deserialize,
                                 connection_args=self.connection_args)
        except Exception:
            self.open -= 1
            raise
        self.active += 1
        occupied.add(comm)

        if self.open >= self.limit:
            self.event.clear()

        raise gen.Return(comm)

    def reuse(self, addr, comm):
        """
        Reuse an open communication to the given address.  For internal use.
        """
        self.occupied[addr].remove(comm)
        self.active -= 1
        if comm.closed():
            self.open -= 1
            if self.open < self.limit:
                self.event.set()
        else:
            self.available[addr].add(comm)

    def collect(self):
        """
        Collect open but unused communications, to allow opening other ones.
        """
        logger.info("Collecting unused comms.  open: %d, active: %d",
                    self.open, self.active)
        for addr, comms in self.available.items():
            for comm in comms:
                comm.close()
            comms.clear()
        self.open = self.active
        if self.open < self.limit:
            self.event.set()

    def close(self):
        """
        Close all communications abruptly.
        """
        for comms in self.available.values():
            for comm in comms:
                comm.abort()
        for comms in self.occupied.values():
            for comm in comms:
                comm.abort()
Example #40
the Fibonacci sequence to listeners.
"""
import os
import signal
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.options import define, options
from tornado.locks import Event
from tornado.web import Application, RequestHandler, StaticFileHandler
from tornadose.handlers import EventSource
from tornadose.stores import DataStore

define('port', default=9000)

store = DataStore()
finish = Event()


def fibonacci():
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b


class MainHandler(RequestHandler):
    def get(self):
        self.render("index.html")


@gen.coroutine
Example #41
 def __init__(self, **settings):
     self.settings = settings
     self._finished = Event()
     self._getters = collections.deque([])  # Futures.
     self._putters = collections.deque([])
     self.initialize(**settings)
Example #42
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class TransferEncodingChunkedHandler(RequestHandler):
            @gen.coroutine
            def head(self):
                self.write("Hello world")
                yield self.flush()

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application([
            ("/", HelloHandler),
            ("/large", LargeHandler),
            ("/chunked", TransferEncodingChunkedHandler),
            (
                "/finish_on_close",
                FinishOnCloseHandler,
                dict(cleanup_event=self.cleanup_event),
            ),
        ])

    def setUp(self):
        super().setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super().tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(
            self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"POST / HTTP/1.0\r\n"
                          b"Connection: keep-alive\r\n"
                          b"Transfer-Encoding: chunked\r\n"
                          b"\r\n"
                          b"0\r\n"
                          b"\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_keepalive_chunked_head_no_body(self):
        yield self.connect()
        self.stream.write(b"HEAD /chunked HTTP/1.1\r\n\r\n")
        yield self.read_headers()

        self.stream.write(b"HEAD /chunked HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
Example #43
 def initialize(self, broker):
     self.event = Event()
     self.broker = broker
Example #44
class GitUpdater:
    def __init__(self, umgr, config, path=None, env=None):
        self.server = umgr.server
        self.execute_cmd = umgr.execute_cmd
        self.execute_cmd_with_response = umgr.execute_cmd_with_response
        self.notify_update_response = umgr.notify_update_response
        self.name = config.get_name().split()[-1]
        self.repo_path = path
        if path is None:
            self.repo_path = config.get('path')
        self.env = config.get("env", env)
        dist_packages = None
        self.python_reqs = None
        if self.env is not None:
            self.env = os.path.expanduser(self.env)
            dist_packages = config.get('python_dist_packages', None)
            self.python_reqs = os.path.join(self.repo_path,
                                            config.get("requirements"))
        self.origin = config.get("origin").lower()
        self.install_script = config.get('install_script', None)
        if self.install_script is not None:
            self.install_script = os.path.abspath(
                os.path.join(self.repo_path, self.install_script))
        self.venv_args = config.get('venv_args', None)
        self.python_dist_packages = None
        self.python_dist_path = None
        self.env_package_path = None
        if dist_packages is not None:
            self.python_dist_packages = [
                p.strip() for p in dist_packages.split('\n') if p.strip()
            ]
            self.python_dist_path = os.path.abspath(
                config.get('python_dist_path'))
            env_package_path = os.path.abspath(
                os.path.join(os.path.dirname(self.env), "..",
                             config.get('env_package_path')))
            matches = glob.glob(env_package_path)
            if len(matches) == 1:
                self.env_package_path = matches[0]
            else:
                raise config.error("No match for 'env_package_path': %s" %
                                   (env_package_path, ))
        for opt in [
                "repo_path", "env", "python_reqs", "install_script",
                "python_dist_path", "env_package_path"
        ]:
            val = getattr(self, opt)
            if val is None:
                continue
            if not os.path.exists(val):
                raise config.error("Invalid path for option '%s': %s" %
                                   (opt, val))

        self.version = self.cur_hash = "?"
        self.remote_version = self.remote_hash = "?"
        self.init_evt = Event()
        self.refresh_condition = None
        self.debug = umgr.repo_debug
        self.remote = "origin"
        self.branch = "master"
        self.is_valid = self.is_dirty = self.detached = False

    def _get_version_info(self):
        ver_path = os.path.join(self.repo_path, "scripts/version.txt")
        vinfo = {}
        if os.path.isfile(ver_path):
            data = ""
            with open(ver_path, 'r') as f:
                data = f.read()
            try:
                entries = [e.strip() for e in data.split('\n') if e.strip()]
                vinfo = dict([i.split('=') for i in entries])
                vinfo = {
                    k: tuple(int(n) for n in re.findall(r"\d+", v))
                    for k, v in vinfo.items()
                }
            except Exception:
                pass
            else:
                self._log_info(f"Version Info Found: {vinfo}")
        vinfo['version'] = tuple(
            int(n) for n in re.findall(r"\d+", self.version))
        return vinfo

    def _log_exc(self, msg, traceback=True):
        log_msg = f"Repo {self.name}: {msg}"
        if traceback:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return self.server.error(msg)

    def _log_info(self, msg):
        log_msg = f"Repo {self.name}: {msg}"
        logging.info(log_msg)

    def _notify_status(self, msg, is_complete=False):
        log_msg = f"Repo {self.name}: {msg}"
        logging.debug(log_msg)
        self.notify_update_response(log_msg, is_complete)

    async def check_initialized(self, timeout=None):
        if self.init_evt.is_set():
            return
        if timeout is not None:
            timeout = IOLoop.current().time() + timeout
        await self.init_evt.wait(timeout)

    async def refresh(self):
        if self.refresh_condition is None:
            self.refresh_condition = Condition()
        else:
            await self.refresh_condition.wait()
            return
        try:
            await self._check_version()
        except Exception:
            logging.exception("Error Refreshing git state")
        self.init_evt.set()
        self.refresh_condition.notify_all()
        self.refresh_condition = None

    async def _check_version(self, need_fetch=True):
        self.is_valid = self.detached = False
        self.cur_hash = self.branch = self.remote = "?"
        self.version = self.remote_version = "?"
        try:
            blist = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} branch --list")
            if blist.startswith("fatal:"):
                self._log_info(f"Invalid git repo at path '{self.repo_path}'")
                return
            branch = None
            for b in blist.split("\n"):
                b = b.strip()
                if b[0] == "*":
                    branch = b[2:]
                    break
            if branch is None:
                self._log_info(
                    "Unable to retreive current branch from branch list\n"
                    f"{blist}")
                return
            if "HEAD detached" in branch:
                bparts = branch.split()[-1].strip("()")
                self.remote, self.branch = bparts.split("/")
                self.detached = True
            else:
                self.branch = branch.strip()
                self.remote = await self.execute_cmd_with_response(
                    f"git -C {self.repo_path} config --get"
                    f" branch.{self.branch}.remote")
            if need_fetch:
                await self.execute_cmd(
                    f"git -C {self.repo_path} fetch {self.remote} --prune -q",
                    retries=3)
            remote_url = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} remote get-url {self.remote}")
            cur_hash = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} rev-parse HEAD")
            remote_hash = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} rev-parse "
                f"{self.remote}/{self.branch}")
            repo_version = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} describe --always "
                "--tags --long --dirty")
            remote_version = await self.execute_cmd_with_response(
                f"git -C {self.repo_path} describe {self.remote}/{self.branch}"
                " --always --tags --long")
        except Exception:
            self._log_exc("Error retreiving git info")
            return

        self.is_dirty = repo_version.endswith("dirty")
        versions = []
        for ver in [repo_version, remote_version]:
            tag_version = "?"
            ver_match = re.match(r"v\d+\.\d+\.\d+-\d+", ver)
            if ver_match:
                tag_version = ver_match.group()
            versions.append(tag_version)
        self.version, self.remote_version = versions
        self.cur_hash = cur_hash.strip()
        self.remote_hash = remote_hash.strip()
        self._log_info(
            f"Repo Detected:\nPath: {self.repo_path}\nRemote: {self.remote}\n"
            f"Branch: {self.branch}\nRemote URL: {remote_url}\n"
            f"Current SHA: {self.cur_hash}\n"
            f"Remote SHA: {self.remote_hash}\nVersion: {self.version}\n"
            f"Remote Version: {self.remote_version}\n"
            f"Is Dirty: {self.is_dirty}\nIs Detached: {self.detached}")
        if self.debug:
            self.is_valid = True
            self._log_info("Debug enabled, bypassing official repo check")
        elif self.branch == "master" and self.remote == "origin":
            if self.detached:
                self._log_info("Detached HEAD detected, repo invalid")
                return
            remote_url = remote_url.lower()
            if remote_url[-4:] != ".git":
                remote_url += ".git"
            if remote_url == self.origin:
                self.is_valid = True
                self._log_info("Validity check for git repo passed")
            else:
                self._log_info(f"Invalid git origin url '{remote_url}'")
        else:
            self._log_info("Git repo not on offical remote/branch: "
                           f"{self.remote}/{self.branch}")

    async def update(self, update_deps=False):
        await self.check_initialized(20.)
        if self.refresh_condition is not None:
            await self.refresh_condition.wait()
        if not self.is_valid:
            raise self._log_exc("Update aborted, repo is not valid", False)
        if self.is_dirty:
            raise self._log_exc("Update aborted, repo is has been modified",
                                False)
        if self.remote_hash == self.cur_hash:
            # No need to update
            return
        self._notify_status("Updating Repo...")
        try:
            if self.detached:
                await self.execute_cmd(
                    f"git -C {self.repo_path} fetch {self.remote} -q",
                    retries=3)
                await self.execute_cmd(f"git -C {self.repo_path} checkout"
                                       f" {self.remote}/{self.branch} -q")
            else:
                await self.execute_cmd(f"git -C {self.repo_path} pull -q",
                                       retries=3)
        except Exception:
            raise self._log_exc("Error running 'git pull'")
        # Check Semantic Versions
        vinfo = self._get_version_info()
        cur_version = vinfo.get('version', ())
        update_deps |= cur_version < vinfo.get('deps_version', ())
        need_env_rebuild = cur_version < vinfo.get('env_version', ())
        if update_deps:
            await self._install_packages()
            await self._update_virtualenv(need_env_rebuild)
        elif need_env_rebuild:
            await self._update_virtualenv(True)
        # Refresh local repo state
        await self._check_version(need_fetch=False)
        if self.name == "moonraker":
            # Launch restart async so the request can return
            # before the server restarts
            self._notify_status("Update Finished...", is_complete=True)
            IOLoop.current().call_later(.1, self.restart_service)
        else:
            await self.restart_service()
            self._notify_status("Update Finished...", is_complete=True)

    async def _install_packages(self):
        if self.install_script is None:
            return
        # Open the install script and read its contents
        inst_path = self.install_script
        if not os.path.isfile(inst_path):
            self._log_info(f"Unable to open install script: {inst_path}")
            return
        with open(inst_path, 'r') as f:
            data = f.read()
        packages = re.findall(r'PKGLIST="(.*)"', data)
        packages = [p.lstrip("${PKGLIST}").strip() for p in packages]
        if not packages:
            self._log_info(f"No packages found in script: {inst_path}")
            return
        # TODO: Log and notify that packages will be installed
        pkgs = " ".join(packages)
        logging.debug(f"Repo {self.name}: Detected Packages: {pkgs}")
        self._notify_status("Installing system dependencies...")
        # Install packages with apt-get
        try:
            await self.execute_cmd(f"{APT_CMD} update",
                                   timeout=300.,
                                   notify=True)
            await self.execute_cmd(f"{APT_CMD} install --yes {pkgs}",
                                   timeout=3600.,
                                   notify=True)
        except Exception:
            self._log_exc("Error updating packages via apt-get")
            return

    async def _update_virtualenv(self, rebuild_env=False):
        if self.env is None:
            return
        # Update python dependencies
        bin_dir = os.path.dirname(self.env)
        env_path = os.path.normpath(os.path.join(bin_dir, ".."))
        if rebuild_env:
            self._notify_status(f"Creating virtualenv at: {env_path}...")
            if os.path.exists(env_path):
                shutil.rmtree(env_path)
            try:
                await self.execute_cmd(
                    f"virtualenv {self.venv_args} {env_path}", timeout=300.)
            except Exception:
                self._log_exc(f"Error creating virtualenv")
                return
            if not os.path.exists(self.env):
                raise self._log_exc("Failed to create new virtualenv", False)
        reqs = self.python_reqs
        if not os.path.isfile(reqs):
            self._log_exc(f"Invalid path to requirements_file '{reqs}'")
            return
        pip = os.path.join(bin_dir, "pip")
        self._notify_status("Updating python packages...")
        try:
            await self.execute_cmd(f"{pip} install -r {reqs}",
                                   timeout=1200.,
                                   notify=True,
                                   retries=3)
        except Exception:
            self._log_exc("Error updating python requirements")
        self._install_python_dist_requirements()

    def _install_python_dist_requirements(self):
        dist_reqs = self.python_dist_packages
        if dist_reqs is None:
            return
        dist_path = self.python_dist_path
        site_path = self.env_package_path
        for pkg in dist_reqs:
            for f in os.listdir(dist_path):
                if f.startswith(pkg):
                    src = os.path.join(dist_path, f)
                    dest = os.path.join(site_path, f)
                    self._notify_status(f"Linking to dist package: {pkg}")
                    if os.path.islink(dest):
                        os.remove(dest)
                    elif os.path.exists(dest):
                        self._notify_status(
                            f"Error symlinking dist package: {pkg}, "
                            f"file already exists: {dest}")
                        continue
                    os.symlink(src, dest)
                    break

    async def restart_service(self):
        self._notify_status("Restarting Service...")
        try:
            await self.execute_cmd(f"sudo systemctl restart {self.name}")
        except Exception:
            raise self._log_exc("Error restarting service")

    def get_update_status(self):
        return {
            'remote_alias': self.remote,
            'branch': self.branch,
            'version': self.version,
            'remote_version': self.remote_version,
            'current_hash': self.cur_hash,
            'remote_hash': self.remote_hash,
            'is_dirty': self.is_dirty,
            'is_valid': self.is_valid,
            'detached': self.detached,
            'debug_enabled': self.debug
        }
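# The semantic-version check in update() compares tuples of integers extracted
# from version strings with re.findall. A small illustration of that
# comparison; the version strings below are made up.
import re


def version_tuple(ver):
    return tuple(int(part) for part in re.findall(r"\d+", ver))


cur = version_tuple("v0.5.1-89")     # -> (0, 5, 1, 89)
deps = version_tuple("v0.5.2-0")     # -> (0, 5, 2, 0)
print(cur < deps)                    # -> True: dependencies need an update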
Example #45
class Queue(Generic[_T]):
    """Coordinate producer and consumer coroutines.

    If maxsize is 0 (the default) the queue size is unbounded.

    .. testcode::

        from tornado import gen
        from tornado.ioloop import IOLoop
        from tornado.queues import Queue

        q = Queue(maxsize=2)

        async def consumer():
            async for item in q:
                try:
                    print('Doing work on %s' % item)
                    await gen.sleep(0.01)
                finally:
                    q.task_done()

        async def producer():
            for item in range(5):
                await q.put(item)
                print('Put %s' % item)

        async def main():
            # Start consumer without waiting (since it never finishes).
            IOLoop.current().spawn_callback(consumer)
            await producer()     # Wait for producer to put all tasks.
            await q.join()       # Wait for consumer to finish all tasks.
            print('Done')

        IOLoop.current().run_sync(main)

    .. testoutput::

        Put 0
        Put 1
        Doing work on 0
        Put 2
        Doing work on 1
        Put 3
        Doing work on 2
        Put 4
        Doing work on 3
        Doing work on 4
        Done


    In versions of Python without native coroutines (before 3.5),
    ``consumer()`` could be written as::

        @gen.coroutine
        def consumer():
            while True:
                item = yield q.get()
                try:
                    print('Doing work on %s' % item)
                    yield gen.sleep(0.01)
                finally:
                    q.task_done()

    .. versionchanged:: 4.3
       Added ``async for`` support in Python 3.5.

    """

    # Exact type depends on subclass. Could be another generic
    # parameter and use protocols to be more precise here.
    _queue = None  # type: Any

    def __init__(self, maxsize: int = 0) -> None:
        if maxsize is None:
            raise TypeError("maxsize can't be None")

        if maxsize < 0:
            raise ValueError("maxsize can't be negative")

        self._maxsize = maxsize
        self._init()
        self._getters = collections.deque([])  # type: Deque[Future[_T]]
        self._putters = collections.deque(
            [])  # type: Deque[Tuple[_T, Future[None]]]
        self._unfinished_tasks = 0
        self._finished = Event()
        self._finished.set()

    @property
    def maxsize(self) -> int:
        """Number of items allowed in the queue."""
        return self._maxsize

    def qsize(self) -> int:
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self) -> bool:
        return not self._queue

    def full(self) -> bool:
        if self.maxsize == 0:
            return False
        else:
            return self.qsize() >= self.maxsize

    def put(self,
            item: _T,
            timeout: Union[float,
                           datetime.timedelta] = None) -> "Future[None]":
        """Put an item into the queue, perhaps waiting until there is room.

        Returns a Future, which raises `tornado.util.TimeoutError` after a
        timeout.

        ``timeout`` may be a number denoting a time (on the same
        scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
        `datetime.timedelta` object for a deadline relative to the
        current time.
        """
        future = Future()  # type: Future[None]
        try:
            self.put_nowait(item)
        except QueueFull:
            self._putters.append((item, future))
            _set_timeout(future, timeout)
        else:
            future.set_result(None)
        return future

    def put_nowait(self, item: _T) -> None:
        """Put an item into the queue without blocking.

        If no free slot is immediately available, raise `QueueFull`.
        """
        self._consume_expired()
        if self._getters:
            assert self.empty(), "queue non-empty, why are getters waiting?"
            getter = self._getters.popleft()
            self.__put_internal(item)
            future_set_result_unless_cancelled(getter, self._get())
        elif self.full():
            raise QueueFull
        else:
            self.__put_internal(item)

    def get(self,
            timeout: Union[float, datetime.timedelta] = None) -> Awaitable[_T]:
        """Remove and return an item from the queue.

        Returns an awaitable which resolves once an item is available, or raises
        `tornado.util.TimeoutError` after a timeout.

        ``timeout`` may be a number denoting a time (on the same
        scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
        `datetime.timedelta` object for a deadline relative to the
        current time.

        .. note::

           The ``timeout`` argument of this method differs from that
           of the standard library's `queue.Queue.get`. That method
           interprets numeric values as relative timeouts; this one
           interprets them as absolute deadlines and requires
           ``timedelta`` objects for relative timeouts (consistent
           with other timeouts in Tornado).

        """
        future = Future()  # type: Future[_T]
        try:
            future.set_result(self.get_nowait())
        except QueueEmpty:
            self._getters.append(future)
            _set_timeout(future, timeout)
        return future

    def get_nowait(self) -> _T:
        """Remove and return an item from the queue without blocking.

        Return an item if one is immediately available, else raise
        `QueueEmpty`.
        """
        self._consume_expired()
        if self._putters:
            assert self.full(), "queue not full, why are putters waiting?"
            item, putter = self._putters.popleft()
            self.__put_internal(item)
            future_set_result_unless_cancelled(putter, None)
            return self._get()
        elif self.qsize():
            return self._get()
        else:
            raise QueueEmpty

    def task_done(self) -> None:
        """Indicate that a formerly enqueued task is complete.

        Used by queue consumers. For each `.get` used to fetch a task, a
        subsequent call to `.task_done` tells the queue that the processing
        on the task is complete.

        If a `.join` is blocking, it resumes when all items have been
        processed; that is, when every `.put` is matched by a `.task_done`.

        Raises `ValueError` if called more times than `.put`.
        """
        if self._unfinished_tasks <= 0:
            raise ValueError("task_done() called too many times")
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    def join(
            self,
            timeout: Union[float,
                           datetime.timedelta] = None) -> Awaitable[None]:
        """Block until all items in the queue are processed.

        Returns an awaitable, which raises `tornado.util.TimeoutError` after a
        timeout.
        """
        return self._finished.wait(timeout)

    def __aiter__(self) -> _QueueIterator[_T]:
        return _QueueIterator(self)

    # These three are overridable in subclasses.
    def _init(self) -> None:
        self._queue = collections.deque()

    def _get(self) -> _T:
        return self._queue.popleft()

    def _put(self, item: _T) -> None:
        self._queue.append(item)

    # End of the overridable methods.

    def __put_internal(self, item: _T) -> None:
        self._unfinished_tasks += 1
        self._finished.clear()
        self._put(item)

    def _consume_expired(self) -> None:
        # Remove timed-out waiters.
        while self._putters and self._putters[0][1].done():
            self._putters.popleft()

        while self._getters and self._getters[0].done():
            self._getters.popleft()

    def __repr__(self) -> str:
        return "<%s at %s %s>" % (type(self).__name__, hex(
            id(self)), self._format())

    def __str__(self) -> str:
        return "<%s %s>" % (type(self).__name__, self._format())

    def _format(self) -> str:
        result = "maxsize=%r" % (self.maxsize, )
        if getattr(self, "_queue", None):
            result += " queue=%r" % self._queue
        if self._getters:
            result += " getters[%s]" % len(self._getters)
        if self._putters:
            result += " putters[%s]" % len(self._putters)
        if self._unfinished_tasks:
            result += " tasks=%s" % self._unfinished_tasks
        return result
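
The ``get``/``put`` docstrings above draw a distinction that is easy to trip over: a plain number is an absolute deadline on the IOLoop clock, while a relative timeout must be passed as a ``datetime.timedelta``. A minimal sketch of both forms (added here for illustration; it is not part of the class above):

import datetime

from tornado.ioloop import IOLoop
from tornado.queues import Queue
from tornado.util import TimeoutError


async def main():
    q = Queue()
    try:
        # Relative timeout: wait at most one second for an item.
        await q.get(timeout=datetime.timedelta(seconds=1))
    except TimeoutError:
        print("get() timed out after one second")

    # Absolute deadline on the IOLoop clock.
    deadline = IOLoop.current().time() + 1
    try:
        await q.get(timeout=deadline)
    except TimeoutError:
        print("get() timed out at the absolute deadline")


IOLoop.current().run_sync(main)
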
Example #46
0
class WorkerProcess(object):
    def __init__(self, worker_args, worker_kwargs, worker_start_args,
                 silence_logs, on_exit):
        self.status = 'init'
        self.silence_logs = silence_logs
        self.worker_args = worker_args
        self.worker_kwargs = worker_kwargs
        self.worker_start_args = worker_start_args
        self.on_exit = on_exit
        self.process = None

        # Initialized when worker is ready
        self.worker_dir = None
        self.worker_address = None

    @gen.coroutine
    def start(self):
        """
        Ensure the worker process is started.
        """
        if self.status == 'running':
            return
        if self.status == 'starting':
            yield self.running.wait()
            return

        self.init_result_q = mp_context.Queue()
        self.child_stop_q = mp_context.Queue()
        self.process = AsyncProcess(
            target=self._run,
            kwargs=dict(worker_args=self.worker_args,
                        worker_kwargs=self.worker_kwargs,
                        worker_start_args=self.worker_start_args,
                        silence_logs=self.silence_logs,
                        init_result_q=self.init_result_q,
                        child_stop_q=self.child_stop_q),
        )
        self.process.daemon = True
        self.process.set_exit_callback(self._on_exit)
        self.running = Event()
        self.stopped = Event()
        self.status = 'starting'
        yield self.process.start()
        if self.status == 'starting':
            yield self._wait_until_running()

    def _on_exit(self, proc):
        if proc is not self.process:
            # Ignore exit of old process instance
            return
        self.mark_stopped()

    def _death_message(self, pid, exitcode):
        assert exitcode is not None
        if exitcode == 255:
            return "Worker process %d was killed by unknown signal" % (pid, )
        elif exitcode >= 0:
            return "Worker process %d exited with status %d" % (
                pid,
                exitcode,
            )
        else:
            return "Worker process %d was killed by signal %d" % (
                pid,
                -exitcode,
            )

    def is_alive(self):
        return self.process is not None and self.process.is_alive()

    @property
    def pid(self):
        return (self.process.pid
                if self.process and self.process.is_alive() else None)

    def mark_stopped(self):
        if self.status != 'stopped':
            r = self.process.exitcode
            assert r is not None
            if r != 0:
                msg = self._death_message(self.process.pid, r)
                logger.warning(msg)
            self.status = 'stopped'
            self.stopped.set()
            # Release resources
            self.process.close()
            self.init_result_q = None
            self.child_stop_q = None
            self.process = None
            # Best effort to clean up worker directory
            if self.worker_dir and os.path.exists(self.worker_dir):
                shutil.rmtree(self.worker_dir, ignore_errors=True)
            self.worker_dir = None
            # User hook
            if self.on_exit is not None:
                self.on_exit(r)

    @gen.coroutine
    def kill(self, timeout=2):
        """
        Ensure the worker process is stopped, waiting at most
        *timeout* seconds before terminating it abruptly.
        """
        loop = IOLoop.current()
        deadline = loop.time() + timeout

        if self.status == 'stopped':
            return
        if self.status == 'stopping':
            yield self.stopped.wait()
            return
        assert self.status in ('starting', 'running')
        self.status = 'stopping'

        process = self.process
        self.child_stop_q.put({
            'op': 'stop',
            'timeout': max(0, deadline - loop.time()) * 0.8,
        })

        while process.is_alive() and loop.time() < deadline:
            yield gen.sleep(0.05)

        if process.is_alive():
            logger.warning(
                "Worker process still alive after %d seconds, killing",
                timeout)
            try:
                yield process.terminate()
            except Exception as e:
                logger.error("Failed to kill worker process: %s", e)

    @gen.coroutine
    def _wait_until_running(self):
        delay = 0.05
        while True:
            if self.status != 'starting':
                raise ValueError("Worker not started")
            try:
                msg = self.init_result_q.get_nowait()
            except Empty:
                yield gen.sleep(delay)
                continue

            if isinstance(msg, Exception):
                yield self.process.join()
                raise msg
            else:
                self.worker_address = msg['address']
                self.worker_dir = msg['dir']
                assert self.worker_address
                self.status = 'running'
                self.running.set()
                raise gen.Return(msg)

    @classmethod
    def _run(cls, worker_args, worker_kwargs, worker_start_args, silence_logs,
             init_result_q, child_stop_q):  # pragma: no cover
        from distributed import Worker

        try:
            from dask.multiprocessing import initialize_worker_process
        except ImportError:  # old Dask version
            pass
        else:
            initialize_worker_process()

        if silence_logs:
            logger.setLevel(silence_logs)

        IOLoop.clear_instance()
        loop = IOLoop()
        loop.make_current()
        worker = Worker(*worker_args, **worker_kwargs)

        @gen.coroutine
        def do_stop(timeout):
            try:
                yield worker._close(report=False, nanny=False)
            finally:
                loop.stop()

        def watch_stop_q():
            """
            Wait for an incoming stop message and then stop the
            worker cleanly.
            """
            while True:
                try:
                    msg = child_stop_q.get(timeout=1000)
                except Empty:
                    pass
                else:
                    assert msg['op'] == 'stop'
                    loop.add_callback(do_stop, msg['timeout'])
                    break

        t = threading.Thread(target=watch_stop_q,
                             name="Nanny stop queue watch")
        t.daemon = True
        t.start()

        @gen.coroutine
        def run():
            """
            Try to start worker and inform parent of outcome.
            """
            try:
                yield worker._start(*worker_start_args)
            except Exception as e:
                logger.exception("Failed to start worker")
                init_result_q.put(e)
            else:
                assert worker.address
                init_result_q.put({
                    'address': worker.address,
                    'dir': worker.local_dir
                })
                yield worker.wait_until_closed()
                logger.info("Worker closed")

        try:
            loop.run_sync(run)
        except TimeoutError:
            # Loop was stopped before wait_until_closed() returned, ignore
            pass
        except KeyboardInterrupt:
            pass
Example #47
0
class AsyncTaskManager(object):
    """
    Aucote uses asynchronous tasks executed in the IOLoop. Some of them,
    especially scanners, should finish before the IOLoop stops.

    This class should be accessed via the `instance` class method, which returns a global instance of the task manager.

    """
    _instances = {}

    TASKS_POLITIC_WAIT = 0
    TASKS_POLITIC_KILL_WORKING_FIRST = 1
    TASKS_POLITIC_KILL_PROPORTIONS = 2
    TASKS_POLITIC_KILL_WORKING = 3

    def __init__(self, parallel_tasks=10):
        self._shutdown_condition = Event()
        self._stop_condition = Event()
        self._cron_tasks = {}
        self._parallel_tasks = parallel_tasks
        self._tasks = Queue()
        self._task_workers = {}
        self._events = {}
        self._limit = self._parallel_tasks
        self._next_task_number = 0
        self._toucan_keys = {}

    @classmethod
    def instance(cls, name=None, **kwargs):
        """
        Return instance of AsyncTaskManager

        Returns:
            AsyncTaskManager

        """
        if cls._instances.get(name) is None:
            cls._instances[name] = AsyncTaskManager(**kwargs)
        return cls._instances[name]

    @property
    def shutdown_condition(self):
        """
        Event which is set once every job is done and the AsyncTaskManager is ready to shut down

        Returns:
            Event
        """
        return self._shutdown_condition

    def start(self):
        """
        Start CronTabCallback tasks

        Returns:
            None

        """
        for task in self._cron_tasks.values():
            task.start()

        for number in range(self._parallel_tasks):
            self._task_workers[number] = IOLoop.current().add_callback(
                partial(self.process_tasks, number))

        self._next_task_number = self._parallel_tasks

    def add_crontab_task(self, task, cron, event=None):
        """
        Add a function to the scheduler to be executed at the given cron time

        Args:
            task (function):
            cron (str): crontab value
            event (Event): event which prevents running tasks with a similar aim, e.g. security scans

        Returns:
            None

        """
        if event is not None:
            event = self._events.setdefault(event, Event())
        self._cron_tasks[task] = AsyncCrontabTask(cron, task, event)

    @gen.coroutine
    def stop(self):
        """
        Stop CronTabCallback tasks and wait on them to finish

        Returns:
            None

        """
        for task in self._cron_tasks.values():
            task.stop()
        IOLoop.current().add_callback(self._prepare_shutdown)
        yield [self._stop_condition.wait(), self._tasks.join()]
        self._shutdown_condition.set()

    def _prepare_shutdown(self):
        """
        Check if ioloop can be stopped

        Returns:
            None

        """
        if any(task.is_running() for task in self._cron_tasks.values()):
            IOLoop.current().add_callback(self._prepare_shutdown)
            return

        self._stop_condition.set()

    def clear(self):
        """
        Clear list of tasks

        Returns:
            None

        """
        self._cron_tasks = {}
        self._shutdown_condition.clear()
        self._stop_condition.clear()

    async def process_tasks(self, number):
        """
        Process the queue. Every task is executed in a separate thread (_Executor)

        """
        log.info("Starting worker %s", number)
        while True:
            try:
                item = self._tasks.get_nowait()
                try:
                    log.debug("Worker %s: starting %s", number, item)
                    thread = _Executor(task=item, number=number)
                    self._task_workers[number] = thread
                    thread.start()

                    while thread.is_alive():
                        await sleep(0.5)
                except:
                    log.exception("Worker %s: exception occurred", number)
                finally:
                    log.debug("Worker %s: %s finished", number, item)
                    self._tasks.task_done()
                    tasks_per_scan = (
                        '{}: {}'.format(scanner, len(tasks))
                        for scanner, tasks in self.tasks_by_scan.items())
                    log.debug("Tasks left in queue: %s (%s)",
                              self.unfinished_tasks, ', '.join(tasks_per_scan))
                    self._task_workers[number] = None
            except QueueEmpty:
                await gen.sleep(0.5)
                if self._stop_condition.is_set() and self._tasks.empty():
                    return
            finally:
                if self._limit < len(self._task_workers):
                    break

        del self._task_workers[number]

        log.info("Closing worker %s", number)

    def add_task(self, task):
        """
        Add task to the queue

        Args:
            task:

        Returns:
            None

        """
        self._tasks.put(task)

    @property
    def unfinished_tasks(self):
        """
        Tasks which are still being processed or are waiting in the queue

        Returns:
            int

        """
        return self._tasks._unfinished_tasks

    @property
    def tasks_by_scan(self):
        """
        Returns queued tasks grouped by scan
        """
        tasks = self._tasks._queue

        return_value = {}

        for task in tasks:
            return_value.setdefault(task.context.scanner.NAME, []).append(task)

        return return_value

    @property
    def cron_tasks(self):
        """
        List of cron tasks

        Returns:
            list

        """
        return self._cron_tasks.values()

    def cron_task(self, name):
        for task in self._cron_tasks.values():
            if task.func.NAME == name:
                return task

    def change_throttling_toucan(self, key, value):
        self.change_throttling(value)

    def change_throttling(self, new_value):
        """
        Change the throttling value, keeping it between 0 and 1.

        The behaviour of the algorithm is described in docs/throttling.md

        Only working tasks are stopped here. Idle workers stop by themselves.

        """
        if new_value > 1:
            new_value = 1
        if new_value < 0:
            new_value = 0

        new_value = round(new_value * 100) / 100

        old_limit = self._limit
        self._limit = round(self._parallel_tasks * float(new_value))

        working_tasks = [
            number for number, task in self._task_workers.items()
            if task is not None
        ]
        current_tasks = len(self._task_workers)

        task_politic = cfg['service.scans.task_politic']

        if task_politic == self.TASKS_POLITIC_KILL_WORKING_FIRST:
            tasks_to_kill = current_tasks - self._limit
        elif task_politic == self.TASKS_POLITIC_KILL_PROPORTIONS:
            tasks_to_kill = round((old_limit - self._limit) *
                                  len(working_tasks) / self._parallel_tasks)
        elif task_politic == self.TASKS_POLITIC_KILL_WORKING:
            tasks_to_kill = (old_limit - self._limit) - (
                len(self._task_workers) - len(working_tasks))
        else:
            tasks_to_kill = 0

        log.debug('%s tasks will be killed', tasks_to_kill)

        for number in working_tasks:
            if tasks_to_kill <= 0:
                break
            self._task_workers[number].stop()
            tasks_to_kill -= 1

        self._limit = round(self._parallel_tasks * float(new_value))

        current_tasks = len(self._task_workers)

        for number in range(self._limit - current_tasks):
            self._task_workers[self._next_task_number] = None
            IOLoop.current().add_callback(
                partial(self.process_tasks, self._next_task_number))
            self._next_task_number += 1
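
The class docstring above says the manager should be reached through its ``instance`` class method. A stripped-down, self-contained sketch of that per-name singleton pattern (``NamedManager`` is a hypothetical stand-in; only the caching logic mirrors ``AsyncTaskManager.instance``):

class NamedManager:
    """Hypothetical stand-in that only reproduces the instance() caching."""
    _instances = {}

    def __init__(self, parallel_tasks=10):
        self.parallel_tasks = parallel_tasks

    @classmethod
    def instance(cls, name=None, **kwargs):
        # Lazily create one manager per name; later callers get the same object.
        if cls._instances.get(name) is None:
            cls._instances[name] = cls(**kwargs)
        return cls._instances[name]


assert NamedManager.instance("scanner", parallel_tasks=4) is NamedManager.instance("scanner")
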
Example #48
0
class DebugpyClient:
    def __init__(self, log, debugpy_stream, event_callback):
        self.log = log
        self.debugpy_stream = debugpy_stream
        self.event_callback = event_callback
        self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
        self.debugpy_host = '127.0.0.1'
        self.debugpy_port = -1
        self.routing_id = None
        self.wait_for_attach = True
        self.init_event = Event()
        self.init_event_seq = -1

    def _get_endpoint(self):
        host, port = self.get_host_port()
        return 'tcp://' + host + ':' + str(port)

    def _forward_event(self, msg):
        if msg['event'] == 'initialized':
            self.init_event.set()
            self.init_event_seq = msg['seq']
        self.event_callback(msg)

    def _send_request(self, msg):
        if self.routing_id is None:
            self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
        content = jsonapi.dumps(msg)
        content_length = str(len(content))
        buf = (DebugpyMessageQueue.HEADER + content_length +
               DebugpyMessageQueue.SEPARATOR).encode('ascii')
        buf += content
        self.log.debug("DEBUGPYCLIENT:")
        self.log.debug(self.routing_id)
        self.log.debug(buf)
        self.debugpy_stream.send_multipart((self.routing_id, buf))

    async def _wait_for_response(self):
        # Since events are never pushed to the message_queue,
        # we can safely assume the next message in the queue
        # will be an answer to the previous request.
        return await self.message_queue.get_message()

    async def _handle_init_sequence(self):
        # 1] Waits for initialized event
        await self.init_event.wait()

        # 2] Sends configurationDone request
        configurationDone = {
            'type': 'request',
            'seq': int(self.init_event_seq) + 1,
            'command': 'configurationDone'
        }
        self._send_request(configurationDone)

        # 3]  Waits for configurationDone response
        await self._wait_for_response()

        # 4] Waits for attachResponse and returns it
        attach_rep = await self._wait_for_response()
        return attach_rep

    def get_host_port(self):
        if self.debugpy_port == -1:
            socket = self.debugpy_stream.socket
            socket.bind_to_random_port('tcp://' + self.debugpy_host)
            self.endpoint = socket.getsockopt(
                zmq.LAST_ENDPOINT).decode('utf-8')
            socket.unbind(self.endpoint)
            index = self.endpoint.rfind(':')
            self.debugpy_port = self.endpoint[index + 1:]
        return self.debugpy_host, self.debugpy_port

    def connect_tcp_socket(self):
        self.debugpy_stream.socket.connect(self._get_endpoint())
        self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)

    def disconnect_tcp_socket(self):
        self.debugpy_stream.socket.disconnect(self._get_endpoint())
        self.routing_id = None
        self.init_event = Event()
        self.init_event_seq = -1
        self.wait_for_attach = True

    def receive_dap_frame(self, frame):
        self.message_queue.put_tcp_frame(frame)

    async def send_dap_request(self, msg):
        self._send_request(msg)
        if self.wait_for_attach and msg['command'] == 'attach':
            rep = await self._handle_init_sequence()
            self.wait_for_attach = False
            return rep
        else:
            rep = await self._wait_for_response()
            self.log.debug('DEBUGPYCLIENT - returning:')
            self.log.debug(rep)
            return rep
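
``_send_request`` above frames each request the way the Debug Adapter Protocol expects: a ``Content-Length`` header, a separator, then the JSON body. A minimal sketch of that framing (the real header and separator constants live on ``DebugpyMessageQueue``, so the literals below are assumptions):

import json


def frame_dap_message(msg):
    # DAP wire format: "Content-Length: <n>\r\n\r\n" followed by the JSON body.
    body = json.dumps(msg).encode("utf-8")
    header = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")
    return header + body


print(frame_dap_message({"type": "request", "seq": 1, "command": "initialize"}))
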
Example #49
0
class Queue(object):
    """Coordinate producer and consumer coroutines.

    If maxsize is 0 (the default) the queue size is unbounded.

    .. testcode::

        from tornado import gen
        from tornado.ioloop import IOLoop
        from tornado.queues import Queue

        q = Queue(maxsize=2)

        @gen.coroutine
        def consumer():
            while True:
                item = yield q.get()
                try:
                    print('Doing work on %s' % item)
                    yield gen.sleep(0.01)
                finally:
                    q.task_done()

        @gen.coroutine
        def producer():
            for item in range(5):
                yield q.put(item)
                print('Put %s' % item)

        @gen.coroutine
        def main():
            # Start consumer without waiting (since it never finishes).
            IOLoop.current().spawn_callback(consumer)
            yield producer()     # Wait for producer to put all tasks.
            yield q.join()       # Wait for consumer to finish all tasks.
            print('Done')

        IOLoop.current().run_sync(main)

    .. testoutput::

        Put 0
        Put 1
        Doing work on 0
        Put 2
        Doing work on 1
        Put 3
        Doing work on 2
        Put 4
        Doing work on 3
        Doing work on 4
        Done

    In Python 3.5, `Queue` implements the async iterator protocol, so
    ``consumer()`` could be rewritten as::

        async def consumer():
            async for item in q:
                try:
                    print('Doing work on %s' % item)
                    await gen.sleep(0.01)
                finally:
                    q.task_done()

    .. versionchanged:: 4.3
       Added ``async for`` support in Python 3.5.

    """
    def __init__(self, maxsize=0):
        if maxsize is None:
            raise TypeError("maxsize can't be None")

        if maxsize < 0:
            raise ValueError("maxsize can't be negative")

        self._maxsize = maxsize
        self._init()
        self._getters = collections.deque([])  # Futures.
        self._putters = collections.deque([])  # Pairs of (item, Future).
        self._unfinished_tasks = 0
        self._finished = Event()
        self._finished.set()

    @property
    def maxsize(self):
        """Number of items allowed in the queue."""
        return self._maxsize

    def qsize(self):
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self):
        return not self._queue

    def full(self):
        if self.maxsize == 0:
            return False
        else:
            return self.qsize() >= self.maxsize

    def put(self, item, timeout=None):
        """Put an item into the queue, perhaps waiting until there is room.

        Returns a Future, which raises `tornado.gen.TimeoutError` after a
        timeout.
        """
        try:
            self.put_nowait(item)
        except QueueFull:
            future = Future()
            self._putters.append((item, future))
            _set_timeout(future, timeout)
            return future
        else:
            return gen._null_future

    def put_nowait(self, item):
        """Put an item into the queue without blocking.

        If no free slot is immediately available, raise `QueueFull`.
        """
        self._consume_expired()
        if self._getters:
            assert self.empty(), "queue non-empty, why are getters waiting?"
            getter = self._getters.popleft()
            self.__put_internal(item)
            getter.set_result(self._get())
        elif self.full():
            raise QueueFull
        else:
            self.__put_internal(item)

    def get(self, timeout=None):
        """Remove and return an item from the queue.

        Returns a Future which resolves once an item is available, or raises
        `tornado.gen.TimeoutError` after a timeout.
        """
        future = Future()
        try:
            future.set_result(self.get_nowait())
        except QueueEmpty:
            self._getters.append(future)
            _set_timeout(future, timeout)
        return future

    def get_nowait(self):
        """Remove and return an item from the queue without blocking.

        Return an item if one is immediately available, else raise
        `QueueEmpty`.
        """
        self._consume_expired()
        if self._putters:
            assert self.full(), "queue not full, why are putters waiting?"
            item, putter = self._putters.popleft()
            self.__put_internal(item)
            putter.set_result(None)
            return self._get()
        elif self.qsize():
            return self._get()
        else:
            raise QueueEmpty

    def task_done(self):
        """Indicate that a formerly enqueued task is complete.

        Used by queue consumers. For each `.get` used to fetch a task, a
        subsequent call to `.task_done` tells the queue that the processing
        on the task is complete.

        If a `.join` is blocking, it resumes when all items have been
        processed; that is, when every `.put` is matched by a `.task_done`.

        Raises `ValueError` if called more times than `.put`.
        """
        if self._unfinished_tasks <= 0:
            raise ValueError('task_done() called too many times')
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    def join(self, timeout=None):
        """Block until all items in the queue are processed.

        Returns a Future, which raises `tornado.gen.TimeoutError` after a
        timeout.
        """
        return self._finished.wait(timeout)

    @gen.coroutine
    def __aiter__(self):
        return _QueueIterator(self)

    # These three are overridable in subclasses.
    def _init(self):
        self._queue = collections.deque()

    def _get(self):
        return self._queue.popleft()

    def _put(self, item):
        self._queue.append(item)

    # End of the overridable methods.

    def __put_internal(self, item):
        self._unfinished_tasks += 1
        self._finished.clear()
        self._put(item)

    def _consume_expired(self):
        # Remove timed-out waiters.
        while self._putters and self._putters[0][1].done():
            self._putters.popleft()

        while self._getters and self._getters[0].done():
            self._getters.popleft()

    def __repr__(self):
        return '<%s at %s %s>' % (type(self).__name__, hex(
            id(self)), self._format())

    def __str__(self):
        return '<%s %s>' % (type(self).__name__, self._format())

    def _format(self):
        result = 'maxsize=%r' % (self.maxsize, )
        if getattr(self, '_queue', None):
            result += ' queue=%r' % self._queue
        if self._getters:
            result += ' getters[%s]' % len(self._getters)
        if self._putters:
            result += ' putters[%s]' % len(self._putters)
        if self._unfinished_tasks:
            result += ' tasks=%s' % self._unfinished_tasks
        return result
Example #51
0
def test_exit_callback():
    to_child = mp_context.Queue()
    from_child = mp_context.Queue()
    evt = Event()

    @gen.coroutine
    def on_stop(_proc):
        assert _proc is proc
        yield gen.moment
        evt.set()

    # Normal process exit
    proc = AsyncProcess(target=feed, args=(to_child, from_child))
    evt.clear()
    proc.set_exit_callback(on_stop)
    proc.daemon = True

    yield proc.start()
    yield gen.sleep(0.05)
    assert proc.is_alive()
    assert not evt.is_set()

    to_child.put(None)
    yield evt.wait(timedelta(seconds=3))
    assert evt.is_set()
    assert not proc.is_alive()

    # Process terminated
    proc = AsyncProcess(target=wait)
    evt.clear()
    proc.set_exit_callback(on_stop)
    proc.daemon = True

    yield proc.start()
    yield gen.sleep(0.05)
    assert proc.is_alive()
    assert not evt.is_set()

    yield proc.terminate()
    yield evt.wait(timedelta(seconds=3))
    assert evt.is_set()
Example #52
0
 def get(self):
     never_finish = Event()
     yield never_finish.wait()
Example #53
0
class WebUpdater:
    def __init__(self, umgr, config):
        self.umgr = umgr
        self.server = umgr.server
        self.notify_update_response = umgr.notify_update_response
        self.repo = config.get('repo').strip().strip("/")
        self.name = self.repo.split("/")[-1]
        if hasattr(config, "get_name"):
            self.name = config.get_name().split()[-1]
        self.path = os.path.realpath(os.path.expanduser(config.get("path")))
        self.persistent_files = []
        pfiles = config.get('persistent_files', None)
        if pfiles is not None:
            self.persistent_files = [
                pf.strip().strip("/") for pf in pfiles.split("\n")
                if pf.strip()
            ]
            if ".version" in self.persistent_files:
                raise config.error(
                    "Invalid value for option 'persistent_files': "
                    "'.version' can not be persistent")

        self.version = self.remote_version = self.dl_url = "?"
        self.etag = None
        self.init_evt = Event()
        self.refresh_condition = None
        self._get_local_version()
        logging.info(f"\nInitializing Client Updater: '{self.name}',"
                     f"\nversion: {self.version}"
                     f"\npath: {self.path}")

    def _get_local_version(self):
        version_path = os.path.join(self.path, ".version")
        if os.path.isfile(version_path):
            with open(version_path, "r") as f:
                v = f.read()
            self.version = v.strip()

    async def check_initialized(self, timeout=None):
        if self.init_evt.is_set():
            return
        if timeout is not None:
            timeout = IOLoop.current().time() + timeout
        await self.init_evt.wait(timeout)

    async def refresh(self):
        if self.refresh_condition is None:
            self.refresh_condition = Condition()
        else:
            await self.refresh_condition.wait()
            return
        try:
            self._get_local_version()
            await self._get_remote_version()
        except Exception:
            logging.exception("Error Refreshing Client")
        self.init_evt.set()
        self.refresh_condition.notify_all()
        self.refresh_condition = None

    async def _get_remote_version(self):
        # Remote state
        url = f"https://api.github.com/repos/{self.repo}/releases/latest"
        try:
            result = await self.umgr.github_api_request(url, etag=self.etag)
        except Exception:
            logging.exception(f"Client {self.repo}: Github Request Error")
            result = {}
        if result is None:
            # No change, update not necessary
            return
        self.etag = result.get('etag', None)
        self.remote_version = result.get('name', "?")
        release_assets = result.get('assets', [{}])[0]
        self.dl_url = release_assets.get('browser_download_url', "?")
        logging.info(f"Github client Info Received:\nRepo: {self.name}\n"
                     f"Local Version: {self.version}\n"
                     f"Remote Version: {self.remote_version}\n"
                     f"url: {self.dl_url}")

    async def update(self, *args):
        await self.check_initialized(20.)
        if self.refresh_condition is not None:
            # wait for refresh if in progress
            await self.refresh_condition.wait()
        if self.remote_version == "?":
            await self.refresh()
            if self.remote_version == "?":
                raise self.server.error(
                    f"Client {self.repo}: Unable to locate update")
        if self.dl_url == "?":
            raise self.server.error(
                f"Client {self.repo}: Invalid download url")
        if self.version == self.remote_version:
            # Already up to date
            return
        with tempfile.TemporaryDirectory(suffix=self.name,
                                         prefix="client") as tempdir:
            if os.path.isdir(self.path):
                # find and move persistent files
                for fname in os.listdir(self.path):
                    src_path = os.path.join(self.path, fname)
                    if fname in self.persistent_files:
                        dest_dir = os.path.dirname(os.path.join(
                            tempdir, fname))
                        os.makedirs(dest_dir, exist_ok=True)
                        shutil.move(src_path, dest_dir)
                shutil.rmtree(self.path)
            os.mkdir(self.path)
            self.notify_update_response(f"Downloading Client: {self.name}")
            archive = await self.umgr.http_download_request(self.dl_url)
            with zipfile.ZipFile(io.BytesIO(archive)) as zf:
                zf.extractall(self.path)
            # Move persistent files back into the client path
            for fname in os.listdir(tempdir):
                src_path = os.path.join(tempdir, fname)
                dest_dir = os.path.dirname(os.path.join(self.path, fname))
                os.makedirs(dest_dir, exist_ok=True)
                shutil.move(src_path, dest_dir)
        self.version = self.remote_version
        version_path = os.path.join(self.path, ".version")
        if not os.path.exists(version_path):
            with open(version_path, "w") as f:
                f.write(self.version)
        self.notify_update_response(f"Client Update Finished: {self.name}",
                                    is_complete=True)

    def get_update_status(self):
        return {
            'name': self.name,
            'version': self.version,
            'remote_version': self.remote_version
        }
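
``_get_remote_version`` above treats a ``None`` result as "nothing changed", which only works because the request carries the cached ETag back to GitHub. A minimal, self-contained sketch of that conditional-request pattern (``github_latest_release`` is a hypothetical helper written for this example, not Moonraker's ``github_api_request``):

import json

from tornado.httpclient import AsyncHTTPClient, HTTPRequest


async def github_latest_release(repo, etag=None):
    url = f"https://api.github.com/repos/{repo}/releases/latest"
    headers = {"User-Agent": "example-updater"}
    if etag is not None:
        headers["If-None-Match"] = etag
    response = await AsyncHTTPClient().fetch(
        HTTPRequest(url, headers=headers), raise_error=False)
    if response.code == 304:
        # Matching ETag: the cached release info is still current.
        return None
    response.rethrow()
    result = json.loads(response.body)
    result['etag'] = response.headers.get("ETag")
    return result
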
Example #54
0
    def __init__(
        self,
        handlers,
        blocked_handlers=None,
        stream_handlers=None,
        connection_limit=512,
        deserialize=True,
        io_loop=None,
    ):
        self.handlers = {
            "identity": self.identity,
            "connection_stream": self.handle_stream,
        }
        self.handlers.update(handlers)
        if blocked_handlers is None:
            blocked_handlers = dask.config.get(
                "distributed.%s.blocked-handlers" %
                type(self).__name__.lower(), [])
        self.blocked_handlers = blocked_handlers
        self.stream_handlers = {}
        self.stream_handlers.update(stream_handlers or {})

        self.id = type(self).__name__ + "-" + str(uuid.uuid4())
        self._address = None
        self._listen_address = None
        self._port = None
        self._comms = {}
        self.deserialize = deserialize
        self.monitor = SystemMonitor()
        self.counters = None
        self.digests = None
        self.events = None
        self.event_counts = None
        self._ongoing_coroutines = weakref.WeakSet()
        self._event_finished = Event()

        self.listeners = []
        self.io_loop = io_loop or IOLoop.current()
        self.loop = self.io_loop

        if not hasattr(self.io_loop, "profile"):
            ref = weakref.ref(self.io_loop)

            if hasattr(self.io_loop, "asyncio_loop"):

                def stop():
                    loop = ref()
                    return loop is None or loop.asyncio_loop.is_closed()

            else:

                def stop():
                    loop = ref()
                    return loop is None or loop._closing

            self.io_loop.profile = profile.watch(
                omit=("profile.py", "selectors.py"),
                interval=dask.config.get(
                    "distributed.worker.profile.interval"),
                cycle=dask.config.get("distributed.worker.profile.cycle"),
                stop=stop,
            )

        # Statistics counters for various events
        with ignoring(ImportError):
            from .counter import Digest

            self.digests = defaultdict(partial(Digest, loop=self.io_loop))

        from .counter import Counter

        self.counters = defaultdict(partial(Counter, loop=self.io_loop))
        self.events = defaultdict(lambda: deque(maxlen=10000))
        self.event_counts = defaultdict(lambda: 0)

        self.periodic_callbacks = dict()

        pc = PeriodicCallback(self.monitor.update, 500, io_loop=self.io_loop)
        self.periodic_callbacks["monitor"] = pc

        self._last_tick = time()
        pc = PeriodicCallback(
            self._measure_tick,
            parse_timedelta(dask.config.get("distributed.admin.tick.interval"),
                            default="ms") * 1000,
            io_loop=self.io_loop,
        )
        self.periodic_callbacks["tick"] = pc

        self.thread_id = 0

        def set_thread_ident():
            self.thread_id = threading.get_ident()

        self.io_loop.add_callback(set_thread_ident)

        self.__stopped = False
Example #55
0
    def __init__(self, config):
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        self.repo_debug = config.getboolean('enable_repo_debug', False)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        if self.repo_debug:
            logging.warn("UPDATE MANAGER: REPO DEBUG ENABLED")
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self),
            "moonraker": GitUpdater(self, mooncfg, MOONRAKER_PATH, env)
        }
        self.current_update = None
        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(self, {
                'repo': client_repo,
                'path': client_path
            })
        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added" %
                                   (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(self, cfg)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(self, cfg)
            else:
                raise config.error("Invalid type '%s' for section [%s]" %
                                   (client_type, section))

        # GitHub API Rate Limit Tracking
        self.gh_rate_limit = None
        self.gh_limit_remaining = None
        self.gh_limit_reset_time = None
        self.gh_init_evt = Event()
        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        AsyncHTTPClient.configure(None, defaults=dict(user_agent="Moonraker"))
        self.http_client = AsyncHTTPClient()

        self.server.register_endpoint("/machine/update/moonraker", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/klipper", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/system", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/client", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)
        self.server.register_notification("update_manager:update_response")
        self.server.register_notification("update_manager:update_refreshed")

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initalize_updaters,
                                        list(self.updaters.values()))
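
The auto-refresh wiring above relies on ``tornado.ioloop.PeriodicCallback``, which reschedules its callback every ``callback_time`` milliseconds once ``start()`` is called. A minimal standalone sketch (the one-second interval and the ``refresh`` function are made up for illustration):

from tornado.ioloop import IOLoop, PeriodicCallback


def refresh():
    print("refreshing update status")


pc = PeriodicCallback(refresh, 1000)  # run refresh() every 1000 ms
pc.start()

loop = IOLoop.current()
loop.call_later(3.5, loop.stop)  # let it fire a few times, then exit
loop.start()
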
Example #56
0
 def __init__(self, rpc_method, params):
     self.id = id(self)
     self.rpc_method = rpc_method
     self.params = params
     self._event = Event()
     self.response = None
Example #57
0
class PollingLock(object):
    """ Acquires a lock by writing to a key. This is suitable for a leader
      election in cases where some downtime and initial acquisition delay is
      acceptable.

      Unlike ZooKeeper and etcd, FoundationDB does not have a way
      to specify that a key should be automatically deleted if a client does
      not heartbeat at a regular interval. This implementation requires the
      leader to update the key at regular intervals to indicate that it is
      still alive. All the other lock candidates check at a longer interval to
      see if the leader has stopped updating the key.

      Since client timestamps are unreliable, candidates do not know the
      absolute time the key was updated. Therefore, they each wait for the full
      timeout interval before checking the key again.
  """
    # The number of seconds to wait before trying to claim the lease.
    _LEASE_TIMEOUT = 60

    # The number of seconds to wait before updating the lease.
    _HEARTBEAT_INTERVAL = int(_LEASE_TIMEOUT / 10)

    def __init__(self, db, tornado_fdb, key):
        self.key = key
        self._db = db
        self._tornado_fdb = tornado_fdb

        self._client_id = uuid.uuid4()
        self._owner = None
        self._op_id = None
        self._deadline = None
        self._event = Event()

    @property
    def acquired(self):
        if self._deadline is None:
            return False

        return (self._owner == self._client_id
                and monotonic.monotonic() < self._deadline)

    def start(self):
        IOLoop.current().spawn_callback(self._run)

    @gen.coroutine
    def acquire(self):
        # Since there is no automatic event timeout, the condition is checked
        # before every acquisition.
        if not self.acquired:
            self._event.clear()

        yield self._event.wait()

    @gen.coroutine
    def _run(self):
        while True:
            try:
                yield self._acquire_lease()
            except Exception:
                logger.exception(u'Unable to acquire lease')
                yield gen.sleep(random.random() * 20)

    @gen.coroutine
    def _acquire_lease(self):
        tr = self._db.create_transaction()
        lease_value = yield self._tornado_fdb.get(tr, self.key)

        if lease_value.present():
            self._owner, new_op_id = fdb.tuple.unpack(lease_value)
            if new_op_id != self._op_id:
                self._deadline = monotonic.monotonic() + self._LEASE_TIMEOUT
                self._op_id = new_op_id
        else:
            self._owner = None

        can_acquire = self._owner is None or monotonic.monotonic(
        ) > self._deadline
        if can_acquire or self._owner == self._client_id:
            op_id = uuid.uuid4()
            tr[self.key] = fdb.tuple.pack((self._client_id, op_id))
            try:
                yield self._tornado_fdb.commit(tr, convert_exceptions=False)
            except fdb.FDBError as fdb_error:
                if fdb_error.code != FDBErrorCodes.NOT_COMMITTED:
                    raise

                # If there was a conflict, try to acquire again later.
                yield gen.sleep(random.random() * 20)
                return

            self._owner = self._client_id
            self._op_id = op_id
            self._deadline = monotonic.monotonic() + self._LEASE_TIMEOUT
            self._event.set()
            if can_acquire:
                logger.info(u'Acquired lock for {!r}'.format(self.key))

            yield gen.sleep(self._HEARTBEAT_INTERVAL)
            return

        # Since another candidate holds the lock, wait until it might expire.
        yield gen.sleep(max(self._deadline - monotonic.monotonic(), 0))
Example #58
0
class WebSocketTest(WebSocketBaseTestCase):
    def get_app(self):
        self.close_future = Future()  # type: Future[None]
        return Application(
            [
                ("/echo", EchoHandler, dict(close_future=self.close_future)),
                ("/non_ws", NonWebSocketHandler),
                ("/header", HeaderHandler, dict(close_future=self.close_future)),
                (
                    "/header_echo",
                    HeaderEchoHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/close_reason",
                    CloseReasonHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/error_in_on_message",
                    ErrorInOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/async_prepare",
                    AsyncPrepareHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/path_args/(.*)",
                    PathArgsHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/coroutine",
                    CoroutineOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
                (
                    "/subprotocol",
                    SubprotocolHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/open_coroutine",
                    OpenCoroutineHandler,
                    dict(close_future=self.close_future, test=self),
                ),
            ],
            template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
        )

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        super(WebSocketTest, self).tearDown()
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch("/echo")
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "13",
            },
        )
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "12",
            },
        )
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self.ws_connect("/echo")
        yield ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    def test_websocket_callbacks(self):
        websocket_connect(
            "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
        )
        ws = self.wait().result()
        ws.write_message("hello")
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, "hello")
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(b"hello \xe9", binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(u"hello \u00e9")
        response = yield ws.read_message()
        self.assertEqual(response, u"hello \u00e9")

    @gen_test
    def test_render_message(self):
        ws = yield self.ws_connect("/render")
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self.ws_connect("/error_in_on_message")
        ws.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect("/notfound")
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_network_fail(self):
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    "ws://127.0.0.1:%d/" % port, connect_timeout=3600
                )

    @gen_test
    def test_websocket_close_buffered_data(self):
        ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
        ws.write_message("hello")
        ws.write_message("world")
        # Close the underlying stream.
        ws.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header" % self.get_http_port(),
                headers={"X-Test": "hello"},
            )
        )
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that the handshake response can carry headers: the request
        # header passed through websocket_connect is echoed back, and the
        # handler can add its own response header.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
                headers={"X-Test-Hello": "hello"},
            )
        )
        self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
        self.assertEqual(
            ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
        )

    @gen_test
    def test_server_close_reason(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        ws = yield self.ws_connect("/echo")
        ws.close(1001, "goodbye")
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect("/async_prepare")
        ws.write_message("hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_path_args(self):
        ws = yield self.ws_connect("/path_args/hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect("/coroutine")
        # Send both messages immediately; the coroutine must process them one at a time.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        res = yield ws.read_message()
        self.assertEqual(res, "hello1")
        res = yield ws.read_message()
        self.assertEqual(res, "hello2")

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d" % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d/something" % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "127.0.0.1:%d" % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        # The host is 127.0.0.1, which should not be accessible from some
        # other domain.
        headers = {"Origin": "http://somewhereelse.com"}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        url = "ws://localhost:%d/echo" % port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {"Origin": "http://subtenant.localhost"}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        ws = yield self.ws_connect(
            "/subprotocol", subprotocols=["badproto", "goodproto"]
        )
        self.assertEqual(ws.selected_subprotocol, "goodproto")
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        ws = yield self.ws_connect("/subprotocol")
        self.assertIs(ws.selected_subprotocol, None)
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        self.message_sent = Event()
        ws = yield self.ws_connect("/open_coroutine")
        yield ws.write_message("hello")
        self.message_sent.set()
        res = yield ws.read_message()
        self.assertEqual(res, "ok")
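
The tests above talk to routes such as "/echo" that are registered in the test application's get_app(); the handler itself is defined elsewhere in the test module. A minimal sketch of what such an echo handler might look like (the class name and exact behavior are assumptions, not taken from the snippet above):

from tornado.websocket import WebSocketHandler


class EchoHandler(WebSocketHandler):
    """Hypothetical handler backing the "/echo" route used by the tests above."""

    def on_message(self, message):
        # Echo every frame back to the client, keeping binary frames binary.
        self.write_message(message, binary=isinstance(message, bytes))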
Example #59
0
    # The yield statements below only run correctly on the IOLoop, so the
    # method needs the @gen_test decorator (from tornado.testing).
    @gen_test
    def test_unused_connection(self):
        stream = yield self.connect()
        event = Event()
        stream.set_close_callback(event.set)
        yield event.wait()
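
The method above relies on scaffolding from its enclosing test case: a connect() helper that opens a raw connection to the test server, and a server option that closes idle connections so event.wait() eventually returns. A sketch under those assumptions (the class name, empty application, and timeout value are illustrative, not taken from the snippet):

import socket

from tornado import gen
from tornado.iostream import IOStream
from tornado.testing import AsyncHTTPTestCase, gen_test  # gen_test decorates the method above
from tornado.web import Application


class IdleConnectionTest(AsyncHTTPTestCase):
    def get_app(self):
        # No routes are needed; the connection is never used.
        return Application([])

    def get_httpserver_options(self):
        # Ask HTTPServer to close connections that stay idle, which is what
        # the close callback in the test above waits for.
        return dict(idle_connection_timeout=0.1)

    @gen.coroutine
    def connect(self):
        # Open a bare TCP connection to the test server without sending a
        # request, so only the idle timeout will close it.
        stream = IOStream(socket.socket())
        yield stream.connect(("127.0.0.1", self.get_http_port()))
        raise gen.Return(stream)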
Example #60
0
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Event

event = Event()


async def consumer():
    print('Waiting for product')

    await event.wait()
    print('Product was found')
    event.clear()

    await event.wait()
    print('Product was found twice')

    return 1


async def producer():
    print("About to set the event")

    await gen.sleep(3)
    print('Set Event')
    event.set()

    await gen.sleep(3)
    print('Set Event')
    event.set()
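
The snippet imports IOLoop but stops before showing how consumer() and producer() are driven; a minimal sketch of a runner, assuming the two coroutines are meant to run concurrently on the same loop:

async def runner():
    # Run both coroutines concurrently; gen.multi returns their results
    # in order, so result[0] is the value returned by consumer().
    result = await gen.multi([consumer(), producer()])
    print('Consumer returned', result[0])


IOLoop.current().run_sync(runner)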