Exemple #1
0
    def __init__(self, ioloop: IOLoop, script_path: str,
                 command_line: Optional[str]):
        """Build the singleton server; the server is not started here.

        Raises
        ------
        RuntimeError
            If a Server instance has already been created.
        """
        if Server._singleton is not None:
            raise RuntimeError(
                "Server already initialized. Use .get_current() instead")

        # Claim the singleton slot before the rest of the state is built.
        Server._singleton = self

        _set_tornado_log_levels()

        self._ioloop = ioloop
        self._script_path = script_path
        self._command_line = command_line

        # Sessions keyed by ReportSession.id.
        self._session_info_by_id: Dict[str, SessionInfo] = {}

        self._must_stop = tornado.locks.Event()
        self._state = State.INITIAL
        self._message_cache = ForwardMsgCache()

        # Subscribe to the file manager's update signal.
        file_mgr = UploadedFileManager()
        self._uploaded_file_mgr = file_mgr
        file_mgr.on_files_updated.connect(self.on_files_updated)

        self._report: Optional[Report] = None
        self._preheated_session_id: Optional[str] = None
        self._has_connection = tornado.locks.Condition()
        self._need_send_data = tornado.locks.Event()

        # Register every stats provider with a fresh StatsManager.
        stats_mgr = StatsManager()
        self._stats_mgr = stats_mgr
        register = stats_mgr.register_provider
        register(get_memo_stats_provider())
        register(get_singleton_stats_provider())
        register(_mem_caches)
        register(self._message_cache)
        register(in_memory_file_manager)
        register(file_mgr)
        register(SessionStateStatProvider(self._session_info_by_id))
Exemple #2
0
    def test_get_stats(self):
        """get_stats() returns the stats of every registered provider, in
        registration order."""
        stats_mgr = StatsManager()
        first = MockStatsProvider()
        second = MockStatsProvider()
        stats_mgr.register_provider(first)
        stats_mgr.register_provider(second)

        # Providers have not reported anything yet.
        self.assertEqual([], stats_mgr.get_stats())

        # Give each provider two stats of its own.
        first.stats = [
            CacheStat("provider1", "foo", 1),
            CacheStat("provider1", "bar", 2),
        ]
        second.stats = [
            CacheStat("provider2", "baz", 3),
            CacheStat("provider2", "qux", 4),
        ]

        expected = first.stats + second.stats
        self.assertEqual(expected, stats_mgr.get_stats())
Exemple #3
0
class Server:
    _singleton: Optional["Server"] = None

    @classmethod
    def get_current(cls) -> "Server":
        """Return the singleton Server object.

        Returns
        -------
        Server
            The instance stored in ``Server._singleton``.

        Raises
        ------
        RuntimeError
            If no Server has been created yet.
        """
        current = Server._singleton
        if current is None:
            raise RuntimeError("Server has not been initialized yet")
        return current

    def __init__(self, ioloop: IOLoop, script_path: str, command_line: Optional[str]):
        """Build the singleton server; the server is not started here.

        A ``command_line`` of None is stored as the empty string.

        Raises
        ------
        RuntimeError
            If a Server instance has already been created.
        """
        if Server._singleton is not None:
            raise RuntimeError("Server already initialized. Use .get_current() instead")

        # Claim the singleton slot before the rest of the state is built.
        Server._singleton = self

        _set_tornado_log_levels()

        self._ioloop = ioloop
        self._script_path = script_path
        self._command_line = "" if command_line is None else command_line

        # Sessions keyed by AppSession.id.
        self._session_info_by_id: Dict[str, SessionInfo] = {}

        self._must_stop = tornado.locks.Event()
        self._state = State.INITIAL
        self._message_cache = ForwardMsgCache()

        # Subscribe to the file manager's update signal.
        file_mgr = UploadedFileManager()
        self._uploaded_file_mgr = file_mgr
        file_mgr.on_files_updated.connect(self.on_files_updated)

        self._session_data: Optional[SessionData] = None
        self._has_connection = tornado.locks.Condition()
        self._need_send_data = tornado.locks.Event()

        # Register every stats provider with a fresh StatsManager.
        stats_mgr = StatsManager()
        self._stats_mgr = stats_mgr
        register = stats_mgr.register_provider
        register(get_memo_stats_provider())
        register(get_singleton_stats_provider())
        register(_mem_caches)
        register(self._message_cache)
        register(in_memory_file_manager)
        register(file_mgr)
        register(SessionStateStatProvider(self._session_info_by_id))

    def __repr__(self) -> str:
        """Return the string that ``util.repr_`` builds for this instance."""
        return util.repr_(self)

    @property
    def script_path(self) -> str:
        """The script path this server was constructed with."""
        return self._script_path

    def get_session_by_id(self, session_id: str) -> Optional[AppSession]:
        """Look up the AppSession registered under ``session_id``.

        Returns None when no session with that id is known.
        """
        info = self._get_session_info(session_id)
        return None if info is None else info.session

    def on_files_updated(self, session_id: str) -> None:
        """Handler connected to the UploadedFileManager's on_files_updated signal.

        Removes the uploaded files of any session id that is not (or is no
        longer) registered with this server.
        """
        if self._get_session_info(session_id) is not None:
            # The session is still known; its files stay.
            return

        # Files that belong to no existing session would otherwise stick
        # around forever.
        self._uploaded_file_mgr.remove_session_files(session_id)

    def _get_session_info(self, session_id: str) -> Optional[SessionInfo]:
        """Fetch the SessionInfo stored under ``session_id``.

        Returns None when the id is not registered.
        """
        return self._session_info_by_id.get(session_id)

    def start(self, on_started: Callable[["Server"], Any]) -> None:
        """Start listening and schedule the server's run-loop.

        Parameters
        ----------
        on_started : callable
            Passed to the run-loop coroutine, which calls it with this
            Server once the loop is about to begin.

        Raises
        ------
        RuntimeError
            If the server is not in its INITIAL state.

        """
        if self._state != State.INITIAL:
            raise RuntimeError("Server has already been started")

        LOGGER.debug("Starting server...")

        start_listening(self._create_app())

        LOGGER.debug("Server started on port %s", config.get_option("server.port"))

        # The run-loop executes as a callback spawned on the ioloop.
        self._ioloop.spawn_callback(self._loop_coroutine, on_started)

    def _create_app(self) -> tornado.web.Application:
        """Assemble the tornado web application and its route table.

        The script health-check route is added only when the
        ``server.scriptHealthCheckEnabled`` option is set. The catch-all
        static-file routes are appended last, and only outside development
        mode.
        """
        base_path = config.get_option("server.baseUrlPath")

        routes: List[Any] = [
            (
                make_url_path_regex(base_path, "stream"),
                _BrowserWebSocketHandler,
                {"server": self},
            ),
            (
                make_url_path_regex(base_path, "healthz"),
                HealthHandler,
                # The property is read lazily, on every health request.
                {"callback": lambda: self.is_ready_for_browser_connection},
            ),
            (
                make_url_path_regex(base_path, "debugz"),
                DebugHandler,
                {"server": self},
            ),
            (
                make_url_path_regex(base_path, "message"),
                MessageCacheHandler,
                {"cache": self._message_cache},
            ),
            (
                make_url_path_regex(base_path, "st-metrics"),
                StatsHandler,
                {"stats_manager": self._stats_mgr},
            ),
            (
                make_url_path_regex(base_path, UPLOAD_FILE_ROUTE),
                UploadFileRequestHandler,
                {
                    "file_mgr": self._uploaded_file_mgr,
                    "get_session_info": self._get_session_info,
                },
            ),
            (
                make_url_path_regex(base_path, "assets/(.*)"),
                AssetsFileHandler,
                {"path": "%s/" % file_util.get_assets_dir()},
            ),
            (
                make_url_path_regex(base_path, "media/(.*)"),
                MediaFileHandler,
                {"path": ""},
            ),
            (
                make_url_path_regex(base_path, "component/(.*)"),
                ComponentRequestHandler,
                {"registry": ComponentRegistry.instance()},
            ),
        ]

        if config.get_option("server.scriptHealthCheckEnabled"):
            routes.append(
                (
                    make_url_path_regex(base_path, "script-health-check"),
                    HealthHandler,
                    {"callback": lambda: self.does_script_run_without_error()},
                )
            )

        if config.get_option("global.developmentMode"):
            LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_dir = file_util.get_static_dir()
            LOGGER.debug("Serving static content from %s", static_dir)

            routes += [
                (
                    make_url_path_regex(base_path, "(.*)"),
                    StaticFileHandler,
                    {"path": "%s/" % static_dir, "default_filename": "index.html"},
                ),
                (make_url_path_regex(base_path, trailing_slash=False), AddSlashHandler),
            ]

        return tornado.web.Application(
            routes,
            cookie_secret=config.get_option("server.cookieSecret"),
            xsrf_cookies=config.get_option("server.enableXsrfProtection"),
            # The default websocket message size is too low, so set it
            # explicitly.
            websocket_max_message_size=get_max_message_size_bytes(),
            **TORNADO_SETTINGS,  # type: ignore[arg-type]
        )

    def _set_state(self, new_state: State) -> None:
        """Record a server state transition.

        Logs the current and the new state at debug level, then stores
        ``new_state`` as the server's state.

        Parameters
        ----------
        new_state : State
            The state to switch to.
        """
        # Hand the values to the logger as arguments so the message is only
        # formatted when debug logging is actually enabled.
        LOGGER.debug("Server state: %s -> %s", self._state, new_state)
        self._state = new_state

    @property
    async def is_ready_for_browser_connection(self) -> Tuple[bool, str]:
        """Health status as an ``(is_ready, message)`` pair.

        This is an async property: reading it yields a coroutine that must
        be awaited. The server counts as unavailable while it is in the
        INITIAL, STOPPING or STOPPED state.
        """
        if self._state in (State.INITIAL, State.STOPPING, State.STOPPED):
            return False, "unavailable"

        return True, "ok"

    async def does_script_run_without_error(self) -> Tuple[bool, str]:
        """Run the app's script in a throwaway session and report the outcome.

        A temporary AppSession is created, a rerun is requested, and the
        session state is polled until SCRIPT_RUN_WITHOUT_ERRORS_KEY shows up
        or SCRIPT_RUN_CHECK_TIMEOUT seconds have elapsed. The session is
        always shut down afterwards.

        Returns
        -------
        (ok, "ok") or (ok, "error") depending on the value stored under
        SCRIPT_RUN_WITHOUT_ERRORS_KEY, or (False, "timeout") if the key never
        appeared in time.
        """
        data = SessionData(self._script_path, self._command_line)
        session = AppSession(
            ioloop=self._ioloop,
            session_data=data,
            uploaded_file_manager=self._uploaded_file_mgr,
            message_enqueued_callback=self._enqueued_some_message,
            local_sources_watcher=LocalSourcesWatcher(data),
        )

        try:
            session.request_rerun(None)

            # Poll every 0.1s until the result key appears or time runs out.
            started_at = time.perf_counter()
            while SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state:
                if (time.perf_counter() - started_at) >= SCRIPT_RUN_CHECK_TIMEOUT:
                    break
                await tornado.gen.sleep(0.1)

            if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state:
                return False, "timeout"

            ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY]
            return ok, ("ok" if ok else "error")
        finally:
            session.shutdown()

    @property
    def browser_is_connected(self) -> bool:
        """True while the server state is ONE_OR_MORE_BROWSERS_CONNECTED."""
        return self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED

    @property
    def is_running_hello(self) -> bool:
        """True when the server's script path equals the file of the
        ``streamlit.hello.hello`` module."""
        # Imported at call time rather than at module level.
        # NOTE(review): presumably to avoid an import cycle or startup cost
        # -- confirm before moving it.
        from streamlit.hello import hello

        return self._script_path == hello.__file__

    @tornado.gen.coroutine
    def _loop_coroutine(
        self, on_started: Optional[Callable[["Server"], Any]] = None
    ) -> Generator[Any, None, None]:
        """Run the server's main loop until ``_must_stop`` is set.

        An INITIAL server is moved to WAITING_FOR_FIRST_BROWSER, then
        ``on_started`` (if given) is called with this Server. Each loop
        iteration either waits for a connection (no browser connected yet /
        any more) or flushes every session's browser queue and sends the
        messages (one or more browsers connected), and then waits until
        either a stop is requested or new data needs sending.

        When the loop ends normally, all sessions are shut down and the state
        becomes STOPPED. Exceptions are printed and logged rather than
        propagated. ``_on_stopped`` is called in every case.

        Parameters
        ----------
        on_started : callable or None
            Called once with this Server before the loop begins.
        """
        try:
            # Only INITIAL and ONE_OR_MORE_BROWSERS_CONNECTED are accepted as
            # starting states; anything else is reported as an error.
            if self._state == State.INITIAL:
                self._set_state(State.WAITING_FOR_FIRST_BROWSER)
            elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
                pass
            else:
                raise RuntimeError("Bad server state at start: %s" % self._state)

            if on_started is not None:
                on_started(self)

            while not self._must_stop.is_set():

                if self._state == State.WAITING_FOR_FIRST_BROWSER:
                    # Sleep until a stop is requested or a browser connects,
                    # whichever happens first.
                    yield tornado.gen.convert_yielded(
                        asyncio.wait(
                            [self._must_stop.wait(), self._has_connection.wait()],
                            return_when=asyncio.FIRST_COMPLETED,
                        )
                    )

                elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
                    # Clear the flag before flushing, so a message enqueued
                    # during the flush sets it again for the wait below.
                    self._need_send_data.clear()

                    # Shallow-clone our sessions into a list, so we can iterate
                    # over it and not worry about whether it's being changed
                    # outside this coroutine.
                    session_infos = list(self._session_info_by_id.values())

                    for session_info in session_infos:
                        msg_list = session_info.session.flush_browser_queue()
                        for msg in msg_list:
                            try:
                                self._send_message(session_info, msg)
                            except tornado.websocket.WebSocketClosedError:
                                # The socket is gone: drop the session.
                                # NOTE(review): the loop keeps iterating over
                                # the remaining messages of this session after
                                # closing it -- confirm this is intended.
                                self._close_app_session(session_info.session.id)
                            # Bare yields give control back between messages
                            # and between sessions.
                            # NOTE(review): presumably so other ioloop
                            # callbacks can run -- confirm against the
                            # tornado.gen semantics of yielding None.
                            yield
                        yield
                    yield tornado.gen.sleep(0.01)

                elif self._state == State.NO_BROWSERS_CONNECTED:
                    # Same wait as for the first browser: stop request or a
                    # new connection.
                    yield tornado.gen.convert_yielded(
                        asyncio.wait(
                            [self._must_stop.wait(), self._has_connection.wait()],
                            return_when=asyncio.FIRST_COMPLETED,
                        )
                    )

                else:
                    # Break out of the thread loop if we encounter any other state.
                    break

                # End of every iteration: block until a stop is requested or
                # some session has enqueued data to send.
                yield tornado.gen.convert_yielded(
                    asyncio.wait(
                        [self._must_stop.wait(), self._need_send_data.wait()],
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                )

            # Shut down all AppSessions
            for session_info in list(self._session_info_by_id.values()):
                session_info.session.shutdown()

            self._set_state(State.STOPPED)

        except Exception:
            # Can't just re-raise here because co-routines use Tornado
            # exceptions for control flow, which appears to swallow the reraised
            # exception.
            traceback.print_exc()
            LOGGER.info(
                """
Please report this bug at https://github.com/streamlit/streamlit/issues.
"""
            )

        finally:
            # Runs on normal exit and after an exception alike.
            self._on_stopped()

    def _send_message(self, session_info: SessionInfo, msg: ForwardMsg) -> None:
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with websocket
        msg : ForwardMsg
            The message to send to the client

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                msg, session_info.session, session_info.script_run_count
            ):
                # This session has probably cached this message. Send
                # a reference instead.
                # The hash is passed as a logging argument so the string is
                # only built when debug logging is enabled; this method runs
                # for every outgoing message.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(
                msg, session_info.session, session_info.script_run_count
            )

        # If this was a successful `script_finished` message, we increment
        # the script_run_count for this session, and drop the session's
        # expired entries from the cache.
        if (
            msg.WhichOneof("type") == "script_finished"
            and msg.script_finished == ForwardMsg.FINISHED_SUCCESSFULLY
        ):
            LOGGER.debug(
                "Script run finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.script_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.script_run_count
            )

        # Ship it off!
        session_info.ws.write_message(serialize_forward_msg(msg_to_send), binary=True)

    def _enqueued_some_message(self) -> None:
        """Schedule ``_need_send_data.set`` as a callback on the ioloop.

        Handed to each AppSession as its ``message_enqueued_callback``.
        """
        self._ioloop.add_callback(self._need_send_data.set)

    def stop(self, from_signal=False) -> None:
        """Request that the server stop.

        Prints a notice, moves the server to the STOPPING state and schedules
        ``_must_stop.set`` on the ioloop.

        Parameters
        ----------
        from_signal : bool
            When true, the callback is scheduled with the ioloop's
            ``add_callback_from_signal`` instead of ``add_callback``.
        """
        click.secho("  Stopping...", fg="blue")
        self._set_state(State.STOPPING)

        schedule = (
            self._ioloop.add_callback_from_signal
            if from_signal
            else self._ioloop.add_callback
        )
        schedule(self._must_stop.set)

    def _on_stopped(self) -> None:
        """Called when our runloop is exiting, to shut down the ioloop.
        This will end our process.

        Invoked from the ``finally`` clause of ``_loop_coroutine``, so it
        runs whether the loop ended normally or with an exception.

        (Tests can patch this method out, to prevent the test's ioloop
        from being shutdown.)
        """
        self._ioloop.stop()

    def _create_app_session(self, ws: WebSocketHandler) -> AppSession:
        """Create an AppSession for a newly connected browser and register it.

        The session is stored under its id together with ``ws``, the server
        moves to ONE_OR_MORE_BROWSERS_CONNECTED, and everything waiting on
        ``_has_connection`` is notified.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler
            The newly-connected websocket handler.

        Returns
        -------
        AppSession
            The newly-created AppSession for this browser connection.

        """
        data = SessionData(self._script_path, self._command_line)
        session = AppSession(
            ioloop=self._ioloop,
            session_data=data,
            uploaded_file_manager=self._uploaded_file_mgr,
            message_enqueued_callback=self._enqueued_some_message,
            local_sources_watcher=LocalSourcesWatcher(data),
        )

        LOGGER.debug(
            "Created new session for ws %s. Session ID: %s", id(ws), session.id
        )

        # A session id must never be registered twice.
        registry = self._session_info_by_id
        assert (
            session.id not in registry
        ), f"session.id '{session.id}' registered multiple times!"

        registry[session.id] = SessionInfo(ws, session)
        self._set_state(State.ONE_OR_MORE_BROWSERS_CONNECTED)
        self._has_connection.notify_all()

        return session

    def _close_app_session(self, session_id: str) -> None:
        """Unregister an AppSession and shut it down.

        Calling this again for a session that is already gone is not an
        error; no session is shut down in that case. Whenever no sessions
        remain afterwards, the server moves to NO_BROWSERS_CONNECTED.

        Parameters
        ----------
        session_id : str
            The AppSession's id string.
        """
        # Unregister first, then shut down the session that was removed.
        closed = self._session_info_by_id.pop(session_id, None)
        if closed is not None:
            closed.session.shutdown()

        if not self._session_info_by_id:
            self._set_state(State.NO_BROWSERS_CONNECTED)