Example #1
# Assumed imports for this snippet: Tornado's coroutine-friendly lock.
from tornado import gen
from tornado.locks import Lock


class ProtoIoGuarder:
    def __init__(self, proto_io):
        self.lock = Lock()
        self.proto_io = proto_io

    @gen.coroutine
    def get(self):
        yield self.lock.acquire()
        raise gen.Return(self.proto_io)

    def release(self):
        self.lock.release()
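A minimal usage sketch for the guard above (send_exclusive and send_request are hypothetical names; only the get()/release() pair comes from the class itself):

@gen.coroutine
def send_exclusive(guarder, request):
    proto_io = yield guarder.get()  # waits until no other coroutine holds the lock
    try:
        # send_request stands in for whatever the protocol object exposes
        response = yield proto_io.send_request(request)
    finally:
        guarder.release()
    raise gen.Return(response)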
Example #2
# Assumed imports: the blocking acquire()/release() calls below match a
# threading Lock rather than Tornado's coroutine lock.
from contextlib import contextmanager
from threading import Lock


class SessionManager(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._lock = Lock()
        self._session_locks = {}

    @contextmanager
    def get(self, key):
        try:
            self._lock.acquire()
            if key not in self:
                self[key] = {}
            if key not in self._session_locks:
                # Sessions seeded via the constructor have no lock yet
                self._session_locks[key] = Lock()
        finally:
            self._lock.release()

        try:
            self._session_locks[key].acquire()
            yield self[key]
        finally:
            self._session_locks[key].release()
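A short usage sketch: each with block holds only that session's lock, so concurrent threads mutate a given session one at a time, while the outer lock guards nothing but session creation.

sessions = SessionManager()

def bump_counter(key):
    with sessions.get(key) as session:
        session['counter'] = session.get('counter', 0) + 1

bump_counter('user-42')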
Example #3
class IPCMessageSubscriber(IPCClient):
    '''
    Salt IPC message subscriber

    Create an IPC client to receive messages from IPC publisher

    An example of a very simple IPCMessageSubscriber connecting to an IPCMessagePublisher.
    This example assumes an already running IPCMessagePublisher.

    IMPORTANT: The below example also assumes the IOLoop is NOT running.

    # Import Tornado libs
    import tornado.ioloop

    # Import Salt libs
    import salt.config
    import salt.transport.ipc

    # Create a new IO Loop.
    # We know that this new IO Loop is not currently running.
    io_loop = tornado.ioloop.IOLoop()

    ipc_publisher_socket_path = '/var/run/ipc_publisher.ipc'

    ipc_subscriber = salt.transport.ipc.IPCMessageSubscriber(ipc_publisher_socket_path, io_loop=io_loop)

    # Connect to the server
    # Use the associated IO Loop that isn't running.
    io_loop.run_sync(ipc_subscriber.connect)

    # Wait for some data
    package = ipc_subscriber.read_sync()
    '''
    def __init__(self, socket_path, io_loop=None):
        super(IPCMessageSubscriber, self).__init__(socket_path,
                                                   io_loop=io_loop)
        self._read_stream_future = None
        self._saved_data = []
        self._read_in_progress = Lock()
        self.callbacks = set()

    @tornado.gen.coroutine
    def _read(self, timeout, callback=None):
        try:
            # An almost-expired timeout makes this effectively a non-blocking
            # acquire: bail out if another read is already in progress.
            yield self._read_in_progress.acquire(timeout=0.00000001)
        except tornado.gen.TimeoutError:
            raise tornado.gen.Return(None)

        log.debug('IPC Subscriber is starting to read')
        exc_to_raise = None
        ret = None
        try:
            while True:
                if self._read_stream_future is None:
                    self._read_stream_future = self.stream.read_bytes(
                        4096, partial=True)

                if timeout is None:
                    wire_bytes = yield self._read_stream_future
                else:
                    wire_bytes = yield FutureWithTimeout(
                        self.io_loop, self._read_stream_future, timeout)
                self._read_stream_future = None

                # Remove the timeout once we get some data or an exception
                # occurs. We will assume that the rest of the data is already
                # there or is coming soon if an exception doesn't occur.
                timeout = None

                self.unpacker.feed(wire_bytes)
                first_sync_msg = True
                for framed_msg in self.unpacker:
                    if callback:
                        self.io_loop.spawn_callback(callback,
                                                    framed_msg['body'])
                    elif first_sync_msg:
                        ret = framed_msg['body']
                        first_sync_msg = False
                    else:
                        self._saved_data.append(framed_msg['body'])
                if not first_sync_msg:
                    # We read at least one piece of data and we're on sync run
                    break
        except TornadoTimeoutError:
            # In the timeout case, just return None.
            # Keep 'self._read_stream_future' alive.
            ret = None
        except StreamClosedError as exc:
            log.trace('Subscriber disconnected from IPC %s', self.socket_path)
            self._read_stream_future = None
            exc_to_raise = exc
        except Exception as exc:
            log.error(
                'Exception occurred in Subscriber while handling stream: %s',
                exc)
            self._read_stream_future = None
            exc_to_raise = exc

        self._read_in_progress.release()

        if exc_to_raise is not None:
            raise exc_to_raise  # pylint: disable=E0702
        raise tornado.gen.Return(ret)

    def read_sync(self, timeout=None):
        '''
        Read a message from an IPC socket

        The socket must already be connected.
        The associated IO Loop must NOT be running.
        :param int timeout: Timeout in seconds when receiving a message
        :return: message data if successful. None if timed out. Will raise an
                 exception for all other error conditions.
        '''
        if self._saved_data:
            return self._saved_data.pop(0)
        return self.io_loop.run_sync(lambda: self._read(timeout))

    def __run_callbacks(self, raw):
        for callback in self.callbacks:
            self.io_loop.spawn_callback(callback, raw)

    @tornado.gen.coroutine
    def read_async(self):
        '''
        Asynchronously read messages and invoke the callbacks registered in
        self.callbacks with the received data.
        '''
        while not self.connected():
            try:
                yield self.connect(timeout=5)
            except StreamClosedError:
                log.trace('Subscriber closed stream on IPC %s before connect',
                          self.socket_path)
                yield tornado.gen.sleep(1)
            except Exception as exc:
                log.error('Exception occurred while Subscriber connecting: %s',
                          exc)
                yield tornado.gen.sleep(1)
        yield self._read(None, self.__run_callbacks)

    def close(self):
        '''
        Routines to handle any cleanup before the instance shuts down.
        Sockets and filehandles should be closed explicitly, to prevent
        leaks.
        '''
        if self._closing:
            return
        super(IPCMessageSubscriber, self).close()
        # This will prevent this message from showing up:
        # '[ERROR   ] Future exception was never retrieved:
        # StreamClosedError'
        if (self._read_stream_future is not None
                and self._read_stream_future.done()):
            exc = self._read_stream_future.exception()
            if exc and not isinstance(exc, StreamClosedError):
                log.error("Read future returned exception %r", exc)
Example #4
class UpdateManager:
    def __init__(self, config):
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        self.repo_debug = config.getboolean('enable_repo_debug', False)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        if self.repo_debug:
            logging.warning("UPDATE MANAGER: REPO DEBUG ENABLED")
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self),
            "moonraker": GitUpdater(self, mooncfg, MOONRAKER_PATH, env)
        }
        self.current_update = None
        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(self, {
                'repo': client_repo,
                'path': client_path
            })
        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added" %
                                   (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(self, cfg)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(self, cfg)
            else:
                raise config.error("Invalid type '%s' for section [%s]" %
                                   (client_type, section))

        # GitHub API Rate Limit Tracking
        self.gh_rate_limit = None
        self.gh_limit_remaining = None
        self.gh_limit_reset_time = None
        self.gh_init_evt = Event()
        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        AsyncHTTPClient.configure(None, defaults=dict(user_agent="Moonraker"))
        self.http_client = AsyncHTTPClient()

        self.server.register_endpoint("/machine/update/moonraker", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/klipper", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/system", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/client", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initialize_updaters,
                                        list(self.updaters.values()))

    async def _initialize_updaters(self, initial_updaters):
        self.is_refreshing = True
        await self._init_api_rate_limit()
        for updater in initial_updaters:
            if isinstance(updater, PackageUpdater):
                ret = updater.refresh(False)
            else:
                ret = updater.refresh()
            if asyncio.iscoroutine(ret):
                await ret
        self.is_refreshing = False

    async def _set_klipper_repo(self):
        kinfo = self.server.get_klippy_info()
        if not kinfo:
            logging.info("No valid klippy info received")
            return
        kpath = kinfo['klipper_path']
        env = kinfo['python_path']
        kupdater = self.updaters.get('klipper', None)
        if kupdater is not None and kupdater.repo_path == kpath and \
                kupdater.env == env:
            # Current Klipper Updater is valid
            return
        kcfg = self.config[f"update_manager static {self.distro} klipper"]
        self.updaters['klipper'] = GitUpdater(self, kcfg, kpath, env)
        await self.updaters['klipper'].refresh()

    async def _check_klippy_printing(self):
        klippy_apis = self.server.lookup_plugin('klippy_apis')
        result = await klippy_apis.query_objects({'print_stats': None},
                                                 default={})
        pstate = result.get('print_stats', {}).get('state', "")
        return pstate.lower() == "printing"

    async def _handle_auto_refresh(self):
        if await self._check_klippy_printing():
            # Don't Refresh during a print
            logging.info("Klippy is printing, auto refresh aborted")
            return
        cur_time = time.time()
        cur_hour = time.localtime(cur_time).tm_hour
        time_diff = cur_time - self.last_auto_update_time
        # Update packages if it has been more than 12 hours
        # and the local time is between 12AM and 5AM
        if time_diff < MIN_REFRESH_TIME or cur_hour >= MAX_PKG_UPDATE_HOUR:
            # Not within the update time window
            return
        self.last_auto_update_time = cur_time
        vinfo = {}
        need_refresh_all = not self.is_refreshing
        async with self.cmd_request_lock:
            self.is_refreshing = True
            try:
                for name, updater in list(self.updaters.items()):
                    if need_refresh_all:
                        ret = updater.refresh()
                        if asyncio.iscoroutine(ret):
                            await ret
                    if hasattr(updater, "get_update_status"):
                        vinfo[name] = updater.get_update_status()
            except Exception:
                logging.exception("Unable to Refresh Status")
                return
            finally:
                self.is_refreshing = False
        uinfo = {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }
        self.server.send_event("update_manager:update_refreshed", uinfo)

    async def _handle_update_request(self, web_request):
        if await self._check_klippy_printing():
            raise self.server.error("Update Refused: Klippy is printing")
        app = web_request.get_endpoint().split("/")[-1]
        if app == "client":
            app = web_request.get('name')
        inc_deps = web_request.get_boolean('include_deps', False)
        if self.current_update is not None and \
                self.current_update[0] == app:
            return f"Object {app} is currently being updated"
        updater = self.updaters.get(app, None)
        if updater is None:
            raise self.server.error(f"Updater {app} not available")
        async with self.cmd_request_lock:
            self.current_update = (app, id(web_request))
            try:
                await updater.update(inc_deps)
            except Exception as e:
                self.notify_update_response(f"Error updating {app}")
                self.notify_update_response(str(e), is_complete=True)
                raise
            finally:
                self.current_update = None
        return "ok"

    async def _handle_status_request(self, web_request):
        check_refresh = web_request.get_boolean('refresh', False)
        # Don't refresh if a print is currently in progress or
        # if an update is in progress.  Just return the current
        # state
        if self.current_update is not None or \
                await self._check_klippy_printing():
            check_refresh = False
        need_refresh = False
        if check_refresh:
            # If there is an outstanding request processing a
            # refresh, we don't need to do it again.
            need_refresh = not self.is_refreshing
            await self.cmd_request_lock.acquire()
            self.is_refreshing = True
        vinfo = {}
        try:
            for name, updater in list(self.updaters.items()):
                await updater.check_initialized(120.)
                if need_refresh:
                    ret = updater.refresh()
                    if asyncio.iscoroutine(ret):
                        await ret
                if hasattr(updater, "get_update_status"):
                    vinfo[name] = updater.get_update_status()
        finally:
            if check_refresh:
                self.is_refreshing = False
                self.cmd_request_lock.release()
        return {
            'version_info': vinfo,
            'github_rate_limit': self.gh_rate_limit,
            'github_requests_remaining': self.gh_limit_remaining,
            'github_limit_reset_time': self.gh_limit_reset_time,
            'busy': self.current_update is not None
        }

    async def execute_cmd(self, cmd, timeout=10., notify=False, retries=1):
        shell_command = self.server.lookup_plugin('shell_command')
        cb = self.notify_update_response if notify else None
        scmd = shell_command.build_shell_command(cmd, callback=cb)
        while retries:
            if await scmd.run(timeout=timeout, verbose=notify):
                break
            retries -= 1
        if not retries:
            raise self.server.error("Shell Command Error")

    async def execute_cmd_with_response(self, cmd, timeout=10.):
        shell_command = self.server.lookup_plugin('shell_command')
        scmd = shell_command.build_shell_command(cmd, None)
        result = await scmd.run_with_response(timeout, retries=5)
        if result is None:
            raise self.server.error(f"Error Running Command: {cmd}")
        return result

    async def _init_api_rate_limit(self):
        url = "https://api.github.com/rate_limit"
        while True:
            try:
                resp = await self.github_api_request(url, is_init=True)
                core = resp['resources']['core']
                self.gh_rate_limit = core['limit']
                self.gh_limit_remaining = core['remaining']
                self.gh_limit_reset_time = core['reset']
            except Exception:
                logging.exception("Error Initializing GitHub API Rate Limit")
                await tornado.gen.sleep(30.)
            else:
                reset_time = time.ctime(self.gh_limit_reset_time)
                logging.info(
                    "GitHub API Rate Limit Initialized\n"
                    f"Rate Limit: {self.gh_rate_limit}\n"
                    f"Rate Limit Remaining: {self.gh_limit_remaining}\n"
                    f"Rate Limit Reset Time: {reset_time}, "
                    f"Seconds Since Epoch: {self.gh_limit_reset_time}")
                break
        self.gh_init_evt.set()

    async def github_api_request(self, url, etag=None, is_init=False):
        if not is_init:
            timeout = time.time() + 30.
            try:
                await self.gh_init_evt.wait(timeout)
            except Exception:
                raise self.server.error("Timeout while waiting for GitHub "
                                        "API Rate Limit initialization")
        if self.gh_limit_remaining == 0:
            curtime = time.time()
            if curtime < self.gh_limit_reset_time:
                raise self.server.error(
                    f"GitHub Rate Limit Reached\nRequest: {url}\n"
                    f"Limit Reset Time: {time.ctime(self.gh_limit_reset_time)}")
        headers = {"Accept": "application/vnd.github.v3+json"}
        if etag is not None:
            headers['If-None-Match'] = etag
        retries = 5
        while retries:
            try:
                timeout = time.time() + 10.
                fut = self.http_client.fetch(url,
                                             headers=headers,
                                             connect_timeout=5.,
                                             request_timeout=5.,
                                             raise_error=False)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                msg = f"Error Processing GitHub API request: {url}"
                if not retries:
                    raise self.server.error(msg)
                logging.exception(msg)
                await tornado.gen.sleep(1.)
                continue
            etag = resp.headers.get('etag', None)
            if etag is not None:
                if etag[:2] == "W/":
                    etag = etag[2:]
            logging.info("GitHub API Request Processed\n"
                         f"URL: {url}\n"
                         f"Response Code: {resp.code}\n"
                         f"Response Reason: {resp.reason}\n"
                         f"ETag: {etag}")
            if resp.code == 403:
                raise self.server.error(
                    f"Forbidden GitHub Request: {resp.reason}")
            elif resp.code == 304:
                logging.info(f"Github Request not Modified: {url}")
                return None
            if resp.code != 200:
                retries -= 1
                if not retries:
                    raise self.server.error(
                        f"Github Request failed: {resp.code} {resp.reason}")
                logging.info(
                    f"Github request error, {retries} retries remaining")
                await tornado.gen.sleep(1.)
                continue
            # Update rate limit on return success
            if 'X-Ratelimit-Limit' in resp.headers and not is_init:
                self.gh_rate_limit = int(resp.headers['X-Ratelimit-Limit'])
                self.gh_limit_remaining = int(
                    resp.headers['X-Ratelimit-Remaining'])
                self.gh_limit_reset_time = float(
                    resp.headers['X-Ratelimit-Reset'])
            decoded = json.loads(resp.body)
            decoded['etag'] = etag
            return decoded

    async def http_download_request(self, url):
        retries = 5
        while retries:
            try:
                timeout = time.time() + 130.
                fut = self.http_client.fetch(
                    url,
                    headers={"Accept": "application/zip"},
                    connect_timeout=5.,
                    request_timeout=120.)
                resp = await tornado.gen.with_timeout(timeout, fut)
            except Exception:
                retries -= 1
                logging.exception("Error Processing Download")
                if not retries:
                    raise
                await tornado.gen.sleep(1.)
                continue
            return resp.body

    def notify_update_response(self, resp, is_complete=False):
        resp = resp.strip()
        if isinstance(resp, bytes):
            resp = resp.decode()
        notification = {
            'message': resp,
            'application': None,
            'proc_id': None,
            'complete': is_complete
        }
        if self.current_update is not None:
            notification['application'] = self.current_update[0]
            notification['proc_id'] = self.current_update[1]
        self.server.send_event("update_manager:update_response", notification)

    def close(self):
        self.http_client.close()
        if self.refresh_cb is not None:
            self.refresh_cb.stop()
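Note the two locking styles above: _handle_update_request takes cmd_request_lock with async with, while _handle_status_request acquires it manually because acquisition is conditional on check_refresh. Both patterns reduced to a minimal sketch (assumes tornado.locks.Lock on Tornado 5+, where native coroutines can use either form):

from tornado import ioloop
from tornado.locks import Lock

lock = Lock()

async def guarded():
    async with lock:              # released automatically, even on error
        pass                      # ... do the work ...

async def guarded_conditionally(refresh):
    if refresh:
        await lock.acquire()      # manual acquire pairs with a finally-release
    try:
        pass                      # ... do the work ...
    finally:
        if refresh:
            lock.release()

async def main():
    await guarded()
    await guarded_conditionally(True)

ioloop.IOLoop.current().run_sync(main)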
Example #5
class UpdateManager:
    def __init__(self, config):
        self.server = config.get_server()
        self.config = config
        self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH)
        auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
        self.distro = config.get('distro', "debian").lower()
        if self.distro not in SUPPORTED_DISTROS:
            raise config.error(f"Unsupported distro: {self.distro}")
        self.cmd_helper = CommandHelper(config)
        env = sys.executable
        mooncfg = self.config[f"update_manager static {self.distro} moonraker"]
        self.updaters = {
            "system": PackageUpdater(self.cmd_helper),
            "moonraker": GitUpdater(mooncfg, self.cmd_helper, MOONRAKER_PATH,
                                    env)
        }
        # TODO: Check for client config in [update_manager].  This is
        # deprecated and will be removed.
        client_repo = config.get("client_repo", None)
        if client_repo is not None:
            client_path = config.get("client_path")
            name = client_repo.split("/")[-1]
            self.updaters[name] = WebUpdater(
                {
                    'repo': client_repo,
                    'path': client_path
                }, self.cmd_helper)
        client_sections = self.config.get_prefix_sections(
            "update_manager client")
        for section in client_sections:
            cfg = self.config[section]
            name = section.split()[-1]
            if name in self.updaters:
                raise config.error("Client repo named %s already added" %
                                   (name, ))
            client_type = cfg.get("type")
            if client_type == "git_repo":
                self.updaters[name] = GitUpdater(cfg, self.cmd_helper)
            elif client_type == "web":
                self.updaters[name] = WebUpdater(cfg, self.cmd_helper)
            else:
                raise config.error("Invalid type '%s' for section [%s]" %
                                   (client_type, section))

        self.cmd_request_lock = Lock()
        self.is_refreshing = False

        # Auto Status Refresh
        self.last_auto_update_time = 0
        self.refresh_cb = None
        if auto_refresh_enabled:
            self.refresh_cb = PeriodicCallback(self._handle_auto_refresh,
                                               UPDATE_REFRESH_INTERVAL_MS)
            self.refresh_cb.start()

        self.server.register_endpoint("/machine/update/moonraker", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/klipper", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/system", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/client", ["POST"],
                                      self._handle_update_request)
        self.server.register_endpoint("/machine/update/status", ["GET"],
                                      self._handle_status_request)
        self.server.register_notification("update_manager:update_response")
        self.server.register_notification("update_manager:update_refreshed")

        # Register Ready Event
        self.server.register_event_handler("server:klippy_identified",
                                           self._set_klipper_repo)
        # Initialize GitHub API Rate Limits and configured updaters
        IOLoop.current().spawn_callback(self._initialize_updaters,
                                        list(self.updaters.values()))

    async def _initialize_updaters(self, initial_updaters):
        self.is_refreshing = True
        await self.cmd_helper.init_api_rate_limit()
        for updater in initial_updaters:
            if isinstance(updater, PackageUpdater):
                ret = updater.refresh(False)
            else:
                ret = updater.refresh()
            if asyncio.iscoroutine(ret):
                await ret
        self.is_refreshing = False

    async def _set_klipper_repo(self):
        kinfo = self.server.get_klippy_info()
        if not kinfo:
            logging.info("No valid klippy info received")
            return
        kpath = kinfo['klipper_path']
        env = kinfo['python_path']
        kupdater = self.updaters.get('klipper', None)
        if kupdater is not None and kupdater.repo_path == kpath and \
                kupdater.env == env:
            # Current Klipper Updater is valid
            return
        kcfg = self.config[f"update_manager static {self.distro} klipper"]
        self.updaters['klipper'] = GitUpdater(kcfg, self.cmd_helper, kpath,
                                              env)
        await self.updaters['klipper'].refresh()

    async def _check_klippy_printing(self):
        klippy_apis = self.server.lookup_plugin('klippy_apis')
        result = await klippy_apis.query_objects({'print_stats': None},
                                                 default={})
        pstate = result.get('print_stats', {}).get('state', "")
        return pstate.lower() == "printing"

    async def _handle_auto_refresh(self):
        if await self._check_klippy_printing():
            # Don't Refresh during a print
            logging.info("Klippy is printing, auto refresh aborted")
            return
        cur_time = time.time()
        cur_hour = time.localtime(cur_time).tm_hour
        time_diff = cur_time - self.last_auto_update_time
        # Update packages if it has been more than 12 hours
        # and the local time is between 12AM and 5AM
        if time_diff < MIN_REFRESH_TIME or cur_hour >= MAX_PKG_UPDATE_HOUR:
            # Not within the update time window
            return
        self.last_auto_update_time = cur_time
        vinfo = {}
        need_refresh_all = not self.is_refreshing
        async with self.cmd_request_lock:
            self.is_refreshing = True
            try:
                for name, updater in list(self.updaters.items()):
                    if need_refresh_all:
                        ret = updater.refresh()
                        if asyncio.iscoroutine(ret):
                            await ret
                    if hasattr(updater, "get_update_status"):
                        vinfo[name] = updater.get_update_status()
            except Exception:
                logging.exception("Unable to Refresh Status")
                return
            finally:
                self.is_refreshing = False
        uinfo = self.cmd_helper.get_rate_limit_stats()
        uinfo['version_info'] = vinfo
        uinfo['busy'] = self.cmd_helper.is_update_busy()
        self.server.send_event("update_manager:update_refreshed", uinfo)

    async def _handle_update_request(self, web_request):
        if await self._check_klippy_printing():
            raise self.server.error("Update Refused: Klippy is printing")
        app = web_request.get_endpoint().split("/")[-1]
        if app == "client":
            app = web_request.get('name')
        inc_deps = web_request.get_boolean('include_deps', False)
        if self.cmd_helper.is_app_updating(app):
            return f"Object {app} is currently being updated"
        updater = self.updaters.get(app, None)
        if updater is None:
            raise self.server.error(f"Updater {app} not available")
        async with self.cmd_request_lock:
            self.cmd_helper.set_update_info(app, id(web_request))
            try:
                await updater.update(inc_deps)
            except Exception as e:
                self.cmd_helper.notify_update_response(f"Error updating {app}")
                self.cmd_helper.notify_update_response(str(e),
                                                       is_complete=True)
                raise
            finally:
                self.cmd_helper.clear_update_info()
        return "ok"

    async def _handle_status_request(self, web_request):
        check_refresh = web_request.get_boolean('refresh', False)
        # Don't refresh if a print is currently in progress or
        # if an update is in progress.  Just return the current
        # state
        if self.cmd_helper.is_update_busy() or \
                await self._check_klippy_printing():
            check_refresh = False
        need_refresh = False
        if check_refresh:
            # If there is an outstanding request processing a
            # refresh, we don't need to do it again.
            need_refresh = not self.is_refreshing
            await self.cmd_request_lock.acquire()
            self.is_refreshing = True
        vinfo = {}
        try:
            for name, updater in list(self.updaters.items()):
                await updater.check_initialized(120.)
                if need_refresh:
                    ret = updater.refresh()
                    if asyncio.iscoroutine(ret):
                        await ret
                if hasattr(updater, "get_update_status"):
                    vinfo[name] = updater.get_update_status()
        finally:
            if check_refresh:
                self.is_refreshing = False
                self.cmd_request_lock.release()
        ret = self.cmd_helper.get_rate_limit_stats()
        ret['version_info'] = vinfo
        ret['busy'] = self.cmd_helper.is_update_busy()
        return ret

    def close(self):
        self.cmd_helper.close()
        if self.refresh_cb is not None:
            self.refresh_cb.stop()
Example #6
class Stream(object):
    def __init__(self, conn, stream_id, delegate, context=None):
        self.conn = conn
        self.stream_id = stream_id
        self.set_delegate(delegate)
        self.context = context
        self.finish_future = Future()
        self.write_lock = Lock()
        from tornado.util import ObjectDict
        # TODO: remove
        self.stream = ObjectDict(io_loop=IOLoop.current(),
                                 close=conn.stream.close)
        self._incoming_content_remaining = None
        self._outgoing_content_remaining = None
        self._delegate_started = False
        self.window = Window(
            conn.window, stream_id,
            conn.setting(constants.Setting.INITIAL_WINDOW_SIZE))
        self._header_frames = []
        self._phase = constants.HTTPPhase.HEADERS

    def set_delegate(self, delegate):
        self.orig_delegate = self.delegate = delegate
        if self.conn.params.decompress:
            self.delegate = _GzipMessageDelegate(delegate,
                                                 self.conn.params.chunk_size)

    def handle_frame(self, frame):
        if frame.type == constants.FrameType.PRIORITY:
            self._handle_priority_frame(frame)
            return
        elif frame.type == constants.FrameType.RST_STREAM:
            self._handle_rst_stream_frame(frame)
            return
        elif frame.type == constants.FrameType.WINDOW_UPDATE:
            self._handle_window_update_frame(frame)
            return
        elif frame.type in (constants.FrameType.SETTINGS,
                            constants.FrameType.GOAWAY,
                            constants.FrameType.PUSH_PROMISE):
            raise Exception("invalid frame type %s for stream", frame.type)

        if self.finish_future.done():
            raise StreamError(self.stream_id,
                              constants.ErrorCode.STREAM_CLOSED)

        if frame.type == constants.FrameType.HEADERS:
            self._handle_headers_frame(frame)
        elif frame.type == constants.FrameType.CONTINUATION:
            self._handle_continuation_frame(frame)
        elif frame.type == constants.FrameType.DATA:
            self._handle_data_frame(frame)
        # Unknown frame types are silently discarded, unless they break
        # the rule that nothing can come between HEADERS and CONTINUATION.

    def needs_continuation(self):
        return bool(self._header_frames)

    def _handle_headers_frame(self, frame):
        if self._phase == constants.HTTPPhase.BODY:
            self._phase = constants.HTTPPhase.TRAILERS
        frame = frame.without_padding()
        self._header_frames.append(frame)
        self._check_header_length()
        if frame.flags & constants.FrameFlag.END_HEADERS:
            self._parse_headers()

    def _handle_continuation_frame(self, frame):
        if not self._header_frames:
            raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                  "CONTINUATION without HEADERS")
        self._header_frames.append(frame)
        self._check_header_length()
        if frame.flags & constants.FrameFlag.END_HEADERS:
            self._parse_headers()

    def _check_header_length(self):
        if (sum(len(f.data) for f in self._header_frames) >
                self.conn.params.max_header_size):
            if self.conn.is_client:
                # TODO: Need tests for client side of headers-too-large.
                # What's the best way to send an error?
                self.delegate.on_connection_close()
            else:
                # write_headers needs a start line so it can tell
                # whether this is a HEAD or not. If we're rejecting
                # the headers we can't know so just make something up.
                # Note that this means the error response body MUST be
                # zero bytes so it doesn't matter whether the client
                # sent a HEAD or a GET.
                self._request_start_line = RequestStartLine(
                    'GET', '/', 'HTTP/2.0')
                start_line = ResponseStartLine('HTTP/2.0', 431,
                                               'Headers too large')
                self.write_headers(start_line, HTTPHeaders())
                self.finish()
            return

    def _parse_headers(self):
        frame = self._header_frames[0]
        data = b''.join(f.data for f in self._header_frames)
        self._header_frames = []
        if frame.flags & constants.FrameFlag.PRIORITY:
            # TODO: support PRIORITY and PADDING.
            # This is just enough to cover an error case tested in h2spec.
            stream_dep, weight = struct.unpack('>ib', data[:5])
            data = data[5:]
            # strip off the "exclusive" bit
            stream_dep = stream_dep & 0x7fffffff
            if stream_dep == frame.stream_id:
                raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                      "stream cannot depend on itself")
        pseudo_headers = {}
        headers = HTTPHeaders()
        try:
            # Pseudo-headers must come before any regular headers,
            # and only in the first HEADERS phase.
            has_regular_header = bool(
                self._phase == constants.HTTPPhase.TRAILERS)
            for k, v, idx in self.conn.hpack_decoder.decode(bytearray(data)):
                if k != k.lower():
                    # RFC section 8.1.2
                    raise StreamError(self.stream_id,
                                      constants.ErrorCode.PROTOCOL_ERROR)
                if k.startswith(b':'):
                    if self.conn.is_client:
                        valid_pseudo_headers = (b':status', )
                    else:
                        valid_pseudo_headers = (b':method', b':scheme',
                                                b':authority', b':path')
                    if (has_regular_header or k not in valid_pseudo_headers
                            or native_str(k) in pseudo_headers):
                        raise StreamError(self.stream_id,
                                          constants.ErrorCode.PROTOCOL_ERROR)
                    pseudo_headers[native_str(k)] = native_str(v)
                    if k == b":authority":
                        headers.add("Host", native_str(v))
                else:
                    headers.add(native_str(k), native_str(v))
                    has_regular_header = True
        except HpackError:
            raise ConnectionError(constants.ErrorCode.COMPRESSION_ERROR)
        if self._phase == constants.HTTPPhase.HEADERS:
            self._start_request(pseudo_headers, headers)
        elif self._phase == constants.HTTPPhase.TRAILERS:
            # TODO: support trailers
            pass
        if (not self._maybe_end_stream(frame.flags)
                and self._phase == constants.HTTPPhase.TRAILERS):
            # The frame that finishes the trailers must also finish
            # the stream.
            raise StreamError(self.stream_id,
                              constants.ErrorCode.PROTOCOL_ERROR)

    def _start_request(self, pseudo_headers, headers):
        if "connection" in headers:
            raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                  "connection header should not be present")
        if "te" in headers and headers["te"] != "trailers":
            raise StreamError(self.stream_id,
                              constants.ErrorCode.PROTOCOL_ERROR)
        if self.conn.is_client:
            status = int(pseudo_headers[':status'])
            start_line = ResponseStartLine('HTTP/2.0', status,
                                           responses.get(status, ''))
        else:
            for k in (':method', ':scheme', ':path'):
                if k not in pseudo_headers:
                    raise StreamError(self.stream_id,
                                      constants.ErrorCode.PROTOCOL_ERROR)
            start_line = RequestStartLine(pseudo_headers[':method'],
                                          pseudo_headers[':path'], 'HTTP/2.0')
            self._request_start_line = start_line

        if (self.conn.is_client and (self._request_start_line.method == 'HEAD'
                                     or start_line.code == 304)):
            self._incoming_content_remaining = 0
        elif "content-length" in headers:
            self._incoming_content_remaining = int(headers["content-length"])

        if not self.conn.is_client or status >= 200:
            self._phase = constants.HTTPPhase.BODY

        self._delegate_started = True
        self.delegate.headers_received(start_line, headers)

    def _handle_data_frame(self, frame):
        if self._header_frames:
            raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                  "DATA without END_HEADERS")
        if self._phase == constants.HTTPPhase.TRAILERS:
            raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                  "DATA after trailers")
        self._phase = constants.HTTPPhase.BODY
        frame = frame.without_padding()
        if self._incoming_content_remaining is not None:
            self._incoming_content_remaining -= len(frame.data)
            if self._incoming_content_remaining < 0:
                raise StreamError(self.stream_id,
                                  constants.ErrorCode.PROTOCOL_ERROR)
        if frame.data and self._delegate_started:
            future = self.delegate.data_received(frame.data)
            if future is None:
                self._send_window_update(len(frame.data))
            else:
                IOLoop.current().add_future(
                    future,
                    lambda f: self._send_window_update(len(frame.data)))
        self._maybe_end_stream(frame.flags)

    def _send_window_update(self, amount):
        encoded = struct.pack('>I', amount)
        for stream_id in (0, self.stream_id):
            self.conn._write_frame(
                Frame(constants.FrameType.WINDOW_UPDATE, 0, stream_id,
                      encoded))

    def _maybe_end_stream(self, flags):
        if flags & constants.FrameFlag.END_STREAM:
            if (self._incoming_content_remaining is not None
                    and self._incoming_content_remaining != 0):
                raise StreamError(self.stream_id,
                                  constants.ErrorCode.PROTOCOL_ERROR)
            if self._delegate_started:
                self._delegate_started = False
                self.delegate.finish()
            self.finish_future.set_result(None)
            return True
        return False

    def _handle_priority_frame(self, frame):
        # TODO: implement priority
        if len(frame.data) != 5:
            raise StreamError(self.stream_id,
                              constants.ErrorCode.FRAME_SIZE_ERROR)

    def _handle_rst_stream_frame(self, frame):
        if len(frame.data) != 4:
            raise ConnectionError(constants.ErrorCode.FRAME_SIZE_ERROR)
        # TODO: expose error code?
        if self._delegate_started:
            self.delegate.on_connection_close()

    def _handle_window_update_frame(self, frame):
        self.window.apply_window_update(frame)

    def set_close_callback(self, callback):
        # TODO: this shouldn't be necessary
        pass

    def reset(self):
        self.conn._write_frame(
            Frame(constants.FrameType.RST_STREAM, 0, self.stream_id,
                  b'\x00\x00\x00\x00'))

    @_reset_on_error
    def write_headers(self, start_line, headers, chunk=None, callback=None):
        if (not self.conn.is_client
                and (self._request_start_line.method == 'HEAD'
                     or start_line.code == 304)):
            self._outgoing_content_remaining = 0
        elif 'Content-Length' in headers:
            self._outgoing_content_remaining = int(headers['Content-Length'])
        header_list = []
        if self.conn.is_client:
            self._request_start_line = start_line
            header_list.append((b':method', utf8(start_line.method),
                                constants.HeaderIndexMode.YES))
            header_list.append(
                (b':scheme', b'https', constants.HeaderIndexMode.YES))
            header_list.append((b':path', utf8(start_line.path),
                                constants.HeaderIndexMode.NO))
        else:
            header_list.append((b':status', utf8(str(start_line.code)),
                                constants.HeaderIndexMode.YES))
        for k, v in headers.get_all():
            k = utf8(k.lower())
            if k == b"connection":
                # Remove the implicit "connection: close", which is not
                # allowed in http2.
                # TODO: move the responsibility for this from httpclient
                # to http1connection?
                continue
            header_list.append((k, utf8(v), constants.HeaderIndexMode.YES))
        data = bytes(self.conn.hpack_encoder.encode(header_list))
        frame = Frame(constants.FrameType.HEADERS,
                      constants.FrameFlag.END_HEADERS, self.stream_id, data)
        self.conn._write_frame(frame)

        return self.write(chunk, callback)

    @_reset_on_error
    def write(self, chunk, callback=None):
        if chunk:
            if self._outgoing_content_remaining is not None:
                self._outgoing_content_remaining -= len(chunk)
                if self._outgoing_content_remaining < 0:
                    raise HTTPOutputError(
                        "Tried to write more data than Content-Length")
        return self._write_chunk(chunk, callback)

    @gen.coroutine
    def _write_chunk(self, chunk, callback=None):
        try:
            if chunk:
                yield self.write_lock.acquire()
                while chunk:
                    bytes_to_write = min(
                        len(chunk),
                        self.conn.setting(constants.Setting.MAX_FRAME_SIZE))
                    allowance = yield self.window.consume(bytes_to_write)

                    yield self.conn._write_frame(
                        Frame(constants.FrameType.DATA, 0, self.stream_id,
                              chunk[:allowance]))
                    chunk = chunk[allowance:]
                self.write_lock.release()
            if callback is not None:
                callback()
        except Exception:
            self.reset()
            raise

    @_reset_on_error
    def finish(self):
        if (self._outgoing_content_remaining is not None
                and self._outgoing_content_remaining != 0):
            raise HTTPOutputError(
                "Tried to write %d bytes less than Content-Length" %
                self._outgoing_content_remaining)
        return self._write_end_stream()

    @gen.coroutine
    def _write_end_stream(self):
        # Callers are not required to wait for write() before calling finish,
        # so we must manually lock.
        yield self.write_lock.acquire()
        try:
            self.conn._write_frame(
                Frame(constants.FrameType.DATA, constants.FrameFlag.END_STREAM,
                      self.stream_id, b''))
        except Exception:
            self.reset()
            raise
        finally:
            self.write_lock.release()

    def read_response(self, delegate):
        assert delegate is self.orig_delegate, 'cannot change delegate'
        return self.finish_future
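The write path above serializes frame writes with write_lock while the flow-control window (window.consume) bounds how much of each chunk may go out. The same acquire-in-coroutine shape, isolated into a runnable sketch (write_serialized and fake_send are illustrative names; only the Lock usage mirrors the class):

from tornado import gen, ioloop
from tornado.locks import Lock

write_lock = Lock()

@gen.coroutine
def write_serialized(send_frame, chunk, max_frame_size):
    # send_frame stands in for a coroutine that writes one DATA frame
    yield write_lock.acquire()
    try:
        while chunk:
            yield send_frame(chunk[:max_frame_size])
            chunk = chunk[max_frame_size:]
    finally:
        write_lock.release()

@gen.coroutine
def fake_send(frame):
    print('wrote %d bytes' % len(frame))

ioloop.IOLoop.current().run_sync(
    lambda: write_serialized(fake_send, b'x' * 10000, 4096))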
Example #7
class CacheBase(Configurable):
    """借助Configurable实现单例
    1. 操作应该都是异步的
    >>> def view(self):
    ...     value = yield cache.get(key)
    """
    @classmethod
    def cached_instances(cls):
        attr_name = '_cached_instances_dict_' + cls.__name__
        if not hasattr(cls, attr_name):
            setattr(cls, attr_name, weakref.WeakKeyDictionary())
        return getattr(cls, attr_name)

    def __new__(cls, io_loop=None, force_instance=False, **kwargs):
        io_loop = io_loop or IOLoop.current()
        if force_instance:
            instance_cache = None
        else:
            instance_cache = cls.cached_instances()
        if instance_cache is not None and io_loop in instance_cache:
            return instance_cache[io_loop]
        instance = super(CacheBase, cls).__new__(cls,
                                                 io_loop=io_loop,
                                                 **kwargs)
        # Make sure the instance knows which cache to remove itself from.
        # It can't simply call _async_clients() because we may be in
        # __new__(AsyncHTTPClient) but instance.__class__ may be
        # SimpleAsyncHTTPClient.
        instance._instance_cache = instance_cache
        if instance_cache is not None:
            instance_cache[instance.io_loop] = instance
        return instance

    def _make_key(self, key, version=None):
        if version is None:
            version = self.version

        new_key = self.key_func(key, self.key_prefix, version)
        return new_key

    def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
        """
        Returns the timeout value usable by this backend based upon the provided
        timeout.
        """
        if timeout == DEFAULT_TIMEOUT:
            timeout = self.default_timeout
        elif timeout == 0:
            # ticket 21147 - avoid time.time() related precision issues
            timeout = -1

        return None if timeout is None else self.io_loop.time() + timeout

    def get(self, key, default=None, version=None):
        raise NotImplementedError(
            'subclasses of BaseCache must provide a get() method')

    def get_sync(self, key, default=None, version=None):
        raise NotImplementedError(
            'subclasses of BaseCache must provide a get_sync() method')

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        raise NotImplementedError(
            'subclasses of BaseCache must provide a set() method')

    def set_sync(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        raise NotImplementedError(
            'subclasses of BaseCache must provide a set_sync() method')

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        raise NotImplementedError(
            'subclasses of BaseCache must provide an add() method')

    def initialize(self, io_loop, defaults=None):
        self.io_loop = io_loop
        self.key_func = get_key_func(getattr(options, 'key_func', None))
        self.default_timeout = getattr(options, 'cache_time', 300)
        self.version = getattr(options, 'version', 1)
        self.key_prefix = getattr(options, 'key_prefix', 'cache')
        self.defaults = dict()
        self._lock = Lock()
        if defaults is not None:
            self.defaults.update(defaults)

    def lock(self, timeout=500):
        # Note: timeout is currently ignored; this returns the Future from
        # tornado's Lock.acquire() and must be yielded by the caller.
        return self._lock.acquire()

    def release(self):
        return self._lock.release()

    @classmethod
    def configurable_base(cls):
        return CacheBase
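A hedged sketch of the lock()/release() pair on a hypothetical CacheBase subclass instance: lock() hands back the Future from tornado's Lock.acquire(), so it must be yielded, and release() belongs in a finally block.

from tornado import gen

@gen.coroutine
def update_cached(cache, key, value):
    yield cache.lock()                # waits until the cache-wide lock is free
    try:
        yield cache.set(key, value)   # implemented by the concrete subclass
    finally:
        cache.release()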
Пример #8
0
class Stream(object):
    def __init__(self, conn, stream_id, delegate, context=None):
        self.conn = conn
        self.stream_id = stream_id
        self.set_delegate(delegate)
        self.context = context
        self.finish_future = Future()
        self.write_lock = Lock()
        from tornado.util import ObjectDict
        # TODO: remove
        self.stream = ObjectDict(io_loop=IOLoop.current(), close=conn.stream.close)
        self._incoming_content_remaining = None
        self._outgoing_content_remaining = None
        self._delegate_started = False
        self.window = Window(conn.window, stream_id,
                             conn.setting(constants.Setting.INITIAL_WINDOW_SIZE))
        self._header_frames = []
        self._phase = constants.HTTPPhase.HEADERS

    def set_delegate(self, delegate):
        self.orig_delegate = self.delegate = delegate
        if self.conn.params.decompress:
            self.delegate = _GzipMessageDelegate(delegate, self.conn.params.chunk_size)

    def handle_frame(self, frame):
        if frame.type == constants.FrameType.PRIORITY:
            self._handle_priority_frame(frame)
            return
        elif frame.type == constants.FrameType.RST_STREAM:
            self._handle_rst_stream_frame(frame)
            return
        elif frame.type == constants.FrameType.WINDOW_UPDATE:
            self._handle_window_update_frame(frame)
            return
        elif frame.type in (constants.FrameType.SETTINGS,
                            constants.FrameType.GOAWAY,
                            constants.FrameType.PUSH_PROMISE):
            raise Exception("invalid frame type %s for stream", frame.type)

        if self.finish_future.done():
            raise StreamError(self.stream_id, constants.ErrorCode.STREAM_CLOSED)

        if frame.type == constants.FrameType.HEADERS:
            self._handle_headers_frame(frame)
        elif frame.type == constants.FrameType.CONTINUATION:
            self._handle_continuation_frame(frame)
        elif frame.type == constants.FrameType.DATA:
            self._handle_data_frame(frame)
        # Unknown frame types are silently discarded, unless they break
        # the rule that nothing can come between HEADERS and CONTINUATION.

    def needs_continuation(self):
        return bool(self._header_frames)

    def _handle_headers_frame(self, frame):
        if self._phase == constants.HTTPPhase.BODY:
            self._phase = constants.HTTPPhase.TRAILERS
        frame = frame.without_padding()
        self._header_frames.append(frame)
        self._check_header_length()
        if frame.flags & constants.FrameFlag.END_HEADERS:
            self._parse_headers()

    def _handle_continuation_frame(self, frame):
        if not self._header_frames:
            raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                  "CONTINUATION without HEADERS")
        self._header_frames.append(frame)
        self._check_header_length()
        if frame.flags & constants.FrameFlag.END_HEADERS:
            self._parse_headers()

    def _check_header_length(self):
        if (sum(len(f.data) for f in self._header_frames) >
                self.conn.params.max_header_size):
            if self.conn.is_client:
                # TODO: Need tests for client side of headers-too-large.
                # What's the best way to send an error?
                self.delegate.on_connection_close()
            else:
                # write_headers needs a start line so it can tell
                # whether this is a HEAD or not. If we're rejecting
                # the headers we can't know so just make something up.
                # Note that this means the error response body MUST be
                # zero bytes so it doesn't matter whether the client
                # sent a HEAD or a GET.
                self._request_start_line = RequestStartLine('GET', '/', 'HTTP/2.0')
                start_line = ResponseStartLine('HTTP/2.0', 431, 'Headers too large')
                self.write_headers(start_line, HTTPHeaders())
                self.finish()
            return

    def _parse_headers(self):
        frame = self._header_frames[0]
        data = b''.join(f.data for f in self._header_frames)
        self._header_frames = []
        if frame.flags & constants.FrameFlag.PRIORITY:
            # TODO: support PRIORITY and PADDING.
            # This is just enough to cover an error case tested in h2spec.
            stream_dep, weight = struct.unpack('>ib', data[:5])
            data = data[5:]
            # strip off the "exclusive" bit
            stream_dep = stream_dep & 0x7fffffff
            if stream_dep == frame.stream_id:
                raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                      "stream cannot depend on itself")
        pseudo_headers = {}
        headers = HTTPHeaders()
        try:
            # Pseudo-headers must come before any regular headers,
            # and only in the first HEADERS phase.
            has_regular_header = (self._phase == constants.HTTPPhase.TRAILERS)
            for k, v, idx in self.conn.hpack_decoder.decode(bytearray(data)):
                if k != k.lower():
                    # RFC section 8.1.2
                    raise StreamError(self.stream_id,
                                      constants.ErrorCode.PROTOCOL_ERROR)
                if k.startswith(b':'):
                    if self.conn.is_client:
                        valid_pseudo_headers = (b':status',)
                    else:
                        valid_pseudo_headers = (b':method', b':scheme',
                                                b':authority', b':path')
                    if (has_regular_header or
                            k not in valid_pseudo_headers or
                            native_str(k) in pseudo_headers):
                        raise StreamError(self.stream_id,
                                          constants.ErrorCode.PROTOCOL_ERROR)
                    pseudo_headers[native_str(k)] = native_str(v)
                    if k == b":authority":
                        headers.add("Host", native_str(v))
                else:
                    headers.add(native_str(k), native_str(v))
                    has_regular_header = True
        except HpackError:
            raise ConnectionError(constants.ErrorCode.COMPRESSION_ERROR)
        if self._phase == constants.HTTPPhase.HEADERS:
            self._start_request(pseudo_headers, headers)
        elif self._phase == constants.HTTPPhase.TRAILERS:
            # TODO: support trailers
            pass
        if (not self._maybe_end_stream(frame.flags) and
                self._phase == constants.HTTPPhase.TRAILERS):
            # The frame that finishes the trailers must also finish
            # the stream.
            raise StreamError(self.stream_id, constants.ErrorCode.PROTOCOL_ERROR)

    def _start_request(self, pseudo_headers, headers):
        if "connection" in headers:
            raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                  "connection header should not be present")
        if "te" in headers and headers["te"] != "trailers":
            raise StreamError(self.stream_id, constants.ErrorCode.PROTOCOL_ERROR)
        if self.conn.is_client:
            status = int(pseudo_headers[':status'])
            start_line = ResponseStartLine('HTTP/2.0', status, responses.get(status, ''))
        else:
            for k in (':method', ':scheme', ':path'):
                if k not in pseudo_headers:
                    raise StreamError(self.stream_id,
                                      constants.ErrorCode.PROTOCOL_ERROR)
            start_line = RequestStartLine(pseudo_headers[':method'],
                                          pseudo_headers[':path'], 'HTTP/2.0')
            self._request_start_line = start_line

        if (self.conn.is_client and
            (self._request_start_line.method == 'HEAD' or
             start_line.code == 304)):
            self._incoming_content_remaining = 0
        elif "content-length" in headers:
            self._incoming_content_remaining = int(headers["content-length"])

        if not self.conn.is_client or status >= 200:
            self._phase = constants.HTTPPhase.BODY

        self._delegate_started = True
        self.delegate.headers_received(start_line, headers)

    def _handle_data_frame(self, frame):
        if self._header_frames:
            raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                  "DATA without END_HEADERS")
        if self._phase == constants.HTTPPhase.TRAILERS:
            raise ConnectionError(constants.ErrorCode.PROTOCOL_ERROR,
                                  "DATA after trailers")
        self._phase = constants.HTTPPhase.BODY
        frame = frame.without_padding()
        if self._incoming_content_remaining is not None:
            self._incoming_content_remaining -= len(frame.data)
            if self._incoming_content_remaining < 0:
                raise StreamError(self.stream_id, constants.ErrorCode.PROTOCOL_ERROR)
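        # Flow control: the window update that re-opens the peer's send
        # window is deferred until the delegate has consumed the data,
        # which provides backpressure when data_received returns a future.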
        if frame.data and self._delegate_started:
            future = self.delegate.data_received(frame.data)
            if future is None:
                self._send_window_update(len(frame.data))
            else:
                IOLoop.current().add_future(
                    future, lambda f: self._send_window_update(len(frame.data)))
        self._maybe_end_stream(frame.flags)

    def _send_window_update(self, amount):
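        # The WINDOW_UPDATE payload is a single big-endian 32-bit integer
        # (the high bit is reserved) giving the number of bytes by which
        # the flow-control window grows. It is sent twice: once for the
        # connection-level window (stream 0) and once for this stream.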
        encoded = struct.pack('>I', amount)
        for stream_id in (0, self.stream_id):
            self.conn._write_frame(Frame(
                constants.FrameType.WINDOW_UPDATE, 0,
                stream_id, encoded))

    def _maybe_end_stream(self, flags):
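        # END_STREAM closes the read side: verify that any declared
        # Content-Length was fully consumed, notify the delegate, and
        # resolve finish_future so read_response() callers wake up.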
        if flags & constants.FrameFlag.END_STREAM:
            if (self._incoming_content_remaining is not None and
                    self._incoming_content_remaining != 0):
                raise StreamError(self.stream_id, constants.ErrorCode.PROTOCOL_ERROR)
            if self._delegate_started:
                self._delegate_started = False
                self.delegate.finish()
            self.finish_future.set_result(None)
            return True
        return False

    def _handle_priority_frame(self, frame):
        # TODO: implement priority
        if len(frame.data) != 5:
            raise StreamError(self.stream_id,
                              constants.ErrorCode.FRAME_SIZE_ERROR)

    def _handle_rst_stream_frame(self, frame):
        if len(frame.data) != 4:
            raise ConnectionError(constants.ErrorCode.FRAME_SIZE_ERROR)
        # TODO: expose error code?
        if self._delegate_started:
            self.delegate.on_connection_close()

    def _handle_window_update_frame(self, frame):
        self.window.apply_window_update(frame)

    def set_close_callback(self, callback):
        # TODO: this shouldn't be necessary
        pass

    def reset(self):
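        # The 4-byte RST_STREAM payload is the error code; all zeros
        # means NO_ERROR.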
        self.conn._write_frame(Frame(constants.FrameType.RST_STREAM,
                                     0, self.stream_id, b'\x00\x00\x00\x00'))

    @_reset_on_error
    def write_headers(self, start_line, headers, chunk=None, callback=None):
        if (not self.conn.is_client and
            (self._request_start_line.method == 'HEAD' or
             start_line.code == 304)):
            self._outgoing_content_remaining = 0
        elif 'Content-Length' in headers:
            self._outgoing_content_remaining = int(headers['Content-Length'])
        header_list = []
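        # Pseudo-headers (":method", ":status", etc.) must precede all
        # regular header fields in the encoded block, so they go into
        # header_list first.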
        if self.conn.is_client:
            self._request_start_line = start_line
            header_list.append((b':method', utf8(start_line.method),
                                constants.HeaderIndexMode.YES))
            header_list.append((b':scheme', b'https',
                                constants.HeaderIndexMode.YES))
            header_list.append((b':path', utf8(start_line.path),
                                constants.HeaderIndexMode.NO))
        else:
            header_list.append((b':status', utf8(str(start_line.code)),
                                constants.HeaderIndexMode.YES))
        for k, v in headers.get_all():
            k = utf8(k.lower())
            if k == b"connection":
                # Remove the implicit "connection: close", which is not
                # allowed in http2.
                # TODO: move the responsibility for this from httpclient
                # to http1connection?
                continue
            header_list.append((k, utf8(v),
                                constants.HeaderIndexMode.YES))
        data = bytes(self.conn.hpack_encoder.encode(header_list))
        frame = Frame(constants.FrameType.HEADERS,
                      constants.FrameFlag.END_HEADERS, self.stream_id,
                      data)
        self.conn._write_frame(frame)

        return self.write(chunk, callback)

    @_reset_on_error
    def write(self, chunk, callback=None):
        if chunk:
            if self._outgoing_content_remaining is not None:
                self._outgoing_content_remaining -= len(chunk)
                if self._outgoing_content_remaining < 0:
                    raise HTTPOutputError(
                        "Tried to write more data than Content-Length")
        return self._write_chunk(chunk, callback)

    @gen.coroutine
    def _write_chunk(self, chunk, callback=None):
        try:
            if chunk:
                yield self.write_lock.acquire()
                try:
                    while chunk:
                        allowance = yield self.window.consume(len(chunk))

                        yield self.conn._write_frame(
                            Frame(constants.FrameType.DATA, 0,
                                  self.stream_id, chunk[:allowance]))
                        chunk = chunk[allowance:]
                finally:
                    # Release the lock even if the flow-control wait or
                    # the write raises; otherwise later writers (including
                    # _write_end_stream) would deadlock.
                    self.write_lock.release()
            if callback is not None:
                callback()
        except Exception:
            self.reset()
            raise

    @_reset_on_error
    def finish(self):
        if (self._outgoing_content_remaining is not None and
                self._outgoing_content_remaining != 0):
            raise HTTPOutputError(
                "Tried to write %d bytes less than Content-Length" %
                self._outgoing_content_remaining)
        return self._write_end_stream()

    @gen.coroutine
    def _write_end_stream(self):
        # Callers are not required to wait for write() before calling finish,
        # so we must manually lock.
        yield self.write_lock.acquire()
        try:
            self.conn._write_frame(Frame(constants.FrameType.DATA,
                                         constants.FrameFlag.END_STREAM,
                                         self.stream_id, b''))
        except Exception:
            self.reset()
            raise
        finally:
            self.write_lock.release()

    def read_response(self, delegate):
        assert delegate is self.orig_delegate, 'cannot change delegate'
        return self.finish_future
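
# A minimal sketch (an assumption for illustration, not part of the code
# above) of the flow-control window object behind `self.window`: the
# stream code only relies on `consume(n)` yielding an allowance of bytes
# it may send, and on `apply_window_update(frame)` re-opening the window.
# Class and attribute names here are hypothetical.
import struct

from tornado import gen, locks


class _WindowSketch(object):
    def __init__(self, initial_size=65535):  # RFC 7540 default window size
        self.size = initial_size
        self._ready = locks.Condition()

    @gen.coroutine
    def consume(self, amount):
        # Wait until at least one byte of window is available, then grant
        # as much of `amount` as the window currently allows.
        while self.size <= 0:
            yield self._ready.wait()
        allowance = min(amount, self.size)
        self.size -= allowance
        raise gen.Return(allowance)

    def apply_window_update(self, frame):
        # The WINDOW_UPDATE payload is one 32-bit integer; the high bit
        # is reserved and must be masked off.
        increment, = struct.unpack('>I', frame.data)
        self.size += increment & 0x7fffffff
        self._ready.notify_all()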