Example #1
0
    def test_add_message(self):
        """add_message registers a reference only for the adding session."""
        msg_cache = ForwardMsgCache()
        owner = _create_mock_session()
        df_msg = _create_dataframe_msg([1, 2, 3])
        msg_cache.add_message(df_msg, owner, 0)

        # The session that added the message holds a reference to it...
        self.assertTrue(msg_cache.has_message_reference(df_msg, owner, 0))
        # ...while a freshly created, unrelated session does not.
        stranger = _create_mock_session()
        self.assertFalse(msg_cache.has_message_reference(df_msg, stranger, 0))

    def test_get_message(self):
        """get_message returns the cached ForwardMsg for its hash."""
        msg_cache = ForwardMsgCache()
        owner = _create_mock_session()
        df_msg = _create_dataframe_msg([1, 2, 3])

        # Compute (and store on the message) the hash used as the cache key.
        expected_hash = populate_hash_if_needed(df_msg)

        msg_cache.add_message(df_msg, owner, 0)
        retrieved = msg_cache.get_message(expected_hash)
        self.assertEqual(df_msg, retrieved)
Example #3
0
class MessageCacheHandlerTest(tornado.testing.AsyncHTTPTestCase):
    """Tests for MessageCacheHandler and ForwardMsgCache expiration."""

    def get_app(self):
        """Serve a MessageCacheHandler at /message backed by a fresh cache."""
        self._cache = ForwardMsgCache()
        return tornado.web.Application([(r"/message", MessageCacheHandler,
                                         dict(cache=self._cache))])

    def test_message_cache(self):
        # Create a new ForwardMsg and cache it
        msg = _create_dataframe_msg([1, 2, 3])
        msg_hash = populate_hash_if_needed(msg)
        self._cache.add_message(msg, MagicMock(), 0)

        # Cache hit
        response = self.fetch("/message?hash=%s" % msg_hash)
        self.assertEqual(200, response.code)
        self.assertEqual(serialize_forward_msg(msg), response.body)

        # Cache misses: no "hash" parameter at all, a wrongly-named
        # parameter, and a well-formed request for a hash never cached.
        self.assertEqual(404, self.fetch("/message").code)
        self.assertEqual(404, self.fetch("/message?id=non_existent").code)
        self.assertEqual(404, self.fetch("/message?hash=non_existent").code)

    def test_message_expiration(self):
        """Test ForwardMsgCache's expiration logic"""
        # Put the global option back when the test ends so that
        # maxCachedMessageAge=1 does not leak into tests that run later.
        orig_max_age = config.get_option("global.maxCachedMessageAge")
        self.addCleanup(config._set_option, "global.maxCachedMessageAge",
                        orig_max_age, "test")
        config._set_option("global.maxCachedMessageAge", 1, "test")

        cache = ForwardMsgCache()
        session1 = _create_mock_session()
        runcount1 = 0

        msg = _create_dataframe_msg([1, 2, 3])
        msg_hash = populate_hash_if_needed(msg)

        cache.add_message(msg, session1, runcount1)

        # Increment session1's run_count. This should not result in expiry.
        runcount1 += 1
        self.assertTrue(cache.has_message_reference(msg, session1, runcount1))

        # Increment again. The message will now be expired for session1,
        # though it won't have actually been removed yet.
        runcount1 += 1
        self.assertFalse(cache.has_message_reference(msg, session1, runcount1))
        self.assertIsNotNone(cache.get_message(msg_hash))

        # Add another reference to the message
        session2 = _create_mock_session()
        runcount2 = 0
        cache.add_message(msg, session2, runcount2)

        # Remove session1's expired entries. This should not remove the
        # entry from the cache, because session2 still has a reference to it.
        cache.remove_expired_session_entries(session1, runcount1)
        self.assertFalse(cache.has_message_reference(msg, session1, runcount1))
        self.assertTrue(cache.has_message_reference(msg, session2, runcount2))

        # Expire session2's reference. The message should no longer be
        # in the cache at all.
        runcount2 += 2
        cache.remove_expired_session_entries(session2, runcount2)
        self.assertIsNone(cache.get_message(msg_hash))
Example #5
0
class Server(object):
    """Singleton web server that connects browsers to ReportSessions.

    Builds the tornado application (websocket, health, debug, metrics,
    message-cache, file-upload, media and static routes), keeps one
    SessionInfo per ReportSession, and runs a coroutine that flushes each
    session's browser queue to that session's websocket.
    """

    _singleton = None  # type: Optional[Server]

    @classmethod
    def get_current(cls):
        """
        Returns
        -------
        Server
            The singleton Server object.

        Raises
        ------
        RuntimeError
            If no Server has been constructed yet.
        """
        if cls._singleton is None:
            raise RuntimeError("Server has not been initialized yet")

        return cls._singleton

    def __init__(self, ioloop, script_path, command_line):
        """Create the server. It won't be started yet.

        Parameters
        ----------
        ioloop : tornado.ioloop.IOLoop
        script_path : str
        command_line : str

        Raises
        ------
        RuntimeError
            If a Server has already been created.

        """
        if Server._singleton is not None:
            raise RuntimeError(
                "Server already initialized. Use .get_current() instead")

        Server._singleton = self

        _set_tornado_log_levels()

        self._ioloop = ioloop
        self._script_path = script_path
        self._command_line = command_line

        # Mapping of ReportSession.id -> SessionInfo.
        self._session_info_by_id = {}

        # Set by stop(); polled by _loop_coroutine to decide when to exit.
        self._must_stop = threading.Event()
        self._state = None
        self._set_state(State.INITIAL)
        self._message_cache = ForwardMsgCache()
        self._uploaded_file_mgr = UploadedFileManager()
        self._uploaded_file_mgr.on_files_added.connect(self._on_file_uploaded)
        self._report = None  # type: Optional[Report]

    def _on_file_uploaded(self, file):
        """Event handler for UploadedFileManager.on_file_added.

        When a file is uploaded by a user, schedule a re-run of the
        corresponding ReportSession.

        Parameters
        ----------
        file : File
            The file that was just uploaded.

        """
        session_info = self._get_session_info(file.session_id)
        if session_info is not None:
            session_info.session.request_rerun()
        else:
            # If an uploaded file doesn't belong to an existing session,
            # remove it so it doesn't stick around forever.
            self._uploaded_file_mgr.remove_files(file.session_id,
                                                 file.widget_id)

    def _get_session_info(self, session_id):
        """Return the SessionInfo with the given id, or None if no such
        session exists.

        Parameters
        ----------
        session_id : str

        Returns
        -------
        SessionInfo or None

        """
        return self._session_info_by_id.get(session_id)

    def start(self, on_started):
        """Start the server.

        Parameters
        ----------
        on_started : callable
            A callback that will be called when the server's run-loop
            has started, and the server is ready to begin receiving clients.

        Raises
        ------
        RuntimeError
            If the server is not in its initial state.

        """
        if self._state != State.INITIAL:
            raise RuntimeError("Server has already been started")

        LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        port = config.get_option("server.port")

        LOGGER.debug("Server started on port %s", port)

        self._ioloop.spawn_callback(self._loop_coroutine, on_started)

    def get_debug(self) -> Dict[str, Dict[str, Any]]:
        """Return debug info for the current report.

        Returns ``{"report": <report debug info>}`` when a report is set,
        otherwise an empty dict.
        """
        if self._report:
            return {"report": self._report.get_debug()}
        return {}

    def _create_app(self):
        """Create our tornado web app.

        All routes are rooted at the ``server.baseUrlPath`` config option.
        Static files are served from the package's static dir unless the
        Node dev server is in use.

        Returns
        -------
        tornado.web.Application

        """
        base = config.get_option("server.baseUrlPath")
        routes = [
            (
                make_url_path_regex(base, "stream"),
                _BrowserWebSocketHandler,
                dict(server=self),
            ),
            (
                make_url_path_regex(base, "healthz"),
                HealthHandler,
                dict(callback=lambda: self.is_ready_for_browser_connection),
            ),
            (
                make_url_path_regex(base, "debugz"),
                DebugHandler,
                dict(server=self),
            ),
            (make_url_path_regex(base, "metrics"), MetricsHandler),
            (
                make_url_path_regex(base, "message"),
                MessageCacheHandler,
                dict(cache=self._message_cache),
            ),
            (
                make_url_path_regex(base, "upload_file"),
                UploadFileRequestHandler,
                dict(file_mgr=self._uploaded_file_mgr),
            ),
            (make_url_path_regex(base, "media/(.*)"), MediaFileHandler),
        ]

        if config.get_option("global.developmentMode") and config.get_option(
                "global.useNode"):
            LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = file_util.get_static_dir()
            LOGGER.debug("Serving static content from %s", static_path)

            routes.extend([
                (
                    make_url_path_regex(base, "(.*)"),
                    StaticFileHandler,
                    {
                        "path": "%s/" % static_path,
                        "default_filename": "index.html"
                    },
                ),
                (
                    make_url_path_regex(base, trailing_slash=False),
                    AddSlashHandler,
                ),
            ])

        return tornado.web.Application(routes, **TORNADO_SETTINGS)

    def _set_state(self, new_state):
        """Record a state transition (and log it at debug level)."""
        LOGGER.debug("Server state: %s -> %s", self._state, new_state)
        self._state = new_state

    @property
    def is_ready_for_browser_connection(self):
        return self._state not in (State.INITIAL, State.STOPPING,
                                   State.STOPPED)

    @property
    def browser_is_connected(self):
        return self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED

    @tornado.gen.coroutine
    def _loop_coroutine(self, on_started=None):
        """Run-loop that flushes each session's browser queue to its socket.

        Loops until ``stop()`` sets ``_must_stop`` or the state leaves the
        known running states, then shuts down every ReportSession and moves
        to ``State.STOPPED``. Unexpected exceptions are printed with their
        traceback. ``_on_stopped()`` is always called on exit.

        Parameters
        ----------
        on_started : callable or None
            Called with this Server once the loop has entered its running
            state.
        """
        try:
            if self._state == State.INITIAL:
                self._set_state(State.WAITING_FOR_FIRST_BROWSER)
            elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
                pass
            else:
                raise RuntimeError("Bad server state at start: %s" %
                                   self._state)

            if on_started is not None:
                on_started(self)

            while not self._must_stop.is_set():

                if self._state == State.WAITING_FOR_FIRST_BROWSER:
                    pass

                elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:

                    # Shallow-clone our sessions into a list, so we can iterate
                    # over it and not worry about whether it's being changed
                    # outside this coroutine.
                    session_infos = list(self._session_info_by_id.values())

                    for session_info in session_infos:
                        if session_info.ws is None:
                            # Preheated.
                            continue
                        msg_list = session_info.session.flush_browser_queue()
                        for msg in msg_list:
                            try:
                                self._send_message(session_info, msg)
                            except tornado.websocket.WebSocketClosedError:
                                self._close_report_session(
                                    session_info.session.id)
                                # The socket is gone. Don't keep sending (and
                                # caching) the rest of this batch for a
                                # session that was just shut down.
                                break
                            yield
                        yield

                elif self._state == State.NO_BROWSERS_CONNECTED:
                    pass

                else:
                    # Break out of the thread loop if we encounter any other state.
                    break

                yield tornado.gen.sleep(0.01)

            # Shut down all ReportSessions
            for session_info in list(self._session_info_by_id.values()):
                session_info.session.shutdown()

            self._set_state(State.STOPPED)

        except Exception as e:
            print("EXCEPTION!", e)
            # print_exc() shows where the exception was raised;
            # print_stack() would only show this handler's own call stack.
            traceback.print_exc(file=sys.stdout)
            LOGGER.info("""
Please report this bug at https://github.com/streamlit/streamlit/issues.
""")

        finally:
            self._on_stopped()

    def _send_message(self, session_info, msg):
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with websocket
        msg : ForwardMsg
            The message to send to the client

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                    msg, session_info.session, session_info.report_run_count):

                # This session has probably cached this message. Send
                # a reference instead.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(msg, session_info.session,
                                            session_info.report_run_count)

        # If this was a `report_finished` message, we increment the
        # report_run_count for this session, and update the cache
        if (msg.WhichOneof("type") == "report_finished"
                and msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY):
            LOGGER.debug(
                "Report finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.report_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.report_run_count)

        # Ship it off!
        session_info.ws.write_message(serialize_forward_msg(msg_to_send),
                                      binary=True)

    def stop(self):
        """Ask the run-loop to exit (it checks ``_must_stop`` each pass)."""
        click.secho("  Stopping...", fg="blue")
        self._set_state(State.STOPPING)
        self._must_stop.set()

    def _on_stopped(self):
        """Called when our runloop is exiting, to shut down the ioloop.
        This will end our process.

        (Tests can patch this method out, to prevent the test's ioloop
        from being shutdown.)
        """
        self._ioloop.stop()

    def add_preheated_report_session(self):
        """Register a fake browser with the server and run the script.

        This is used to start running the user's script even before the first
        browser connects.
        """
        session = self._create_report_session(ws=None)
        session.handle_rerun_script_request(is_preheat=True)

    def _create_report_session(self, ws):
        """Register a connected browser with the server.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler or None
            The newly-connected websocket handler or None if preheated
            connection.

        Returns
        -------
        ReportSession
            The newly-created ReportSession for this browser connection.

        """
        if PREHEATED_ID in self._session_info_by_id:
            assert len(self._session_info_by_id) == 1
            LOGGER.debug("Reusing preheated context for ws %s", ws)
            session = self._session_info_by_id[PREHEATED_ID].session
            del self._session_info_by_id[PREHEATED_ID]
            # NOTE(review): the reused session's id is reset, presumably so it
            # is no longer registered under PREHEATED_ID -- confirm 0 cannot
            # collide with a real session id.
            session.id = 0
        else:
            LOGGER.debug("Creating new context for ws %s", ws)
            session = ReportSession(
                is_preheat=(ws is None),
                ioloop=self._ioloop,
                script_path=self._script_path,
                command_line=self._command_line,
                uploaded_file_manager=self._uploaded_file_mgr,
            )

        assert session.id not in self._session_info_by_id, (
            "session.id '%s' registered multiple times!" % session.id)

        self._session_info_by_id[session.id] = SessionInfo(ws, session)

        if ws is not None:
            self._set_state(State.ONE_OR_MORE_BROWSERS_CONNECTED)

        return session

    def _close_report_session(self, session_id):
        """Shutdown and remove a ReportSession.

        This function may be called multiple times for the same session,
        which is not an error. (Subsequent calls just no-op.)

        Parameters
        ----------
        session_id : str
            The ReportSession's id string.
        """
        session_info = self._session_info_by_id.pop(session_id, None)
        if session_info is not None:
            session_info.session.shutdown()

        if not self._session_info_by_id:
            self._set_state(State.NO_BROWSERS_CONNECTED)
Example #6
0
class Server(object):
    """Singleton web server that connects browsers to ReportSessions.

    Builds the tornado application, keeps one SessionInfo per connected
    websocket (or the preheated placeholder), and runs a coroutine that
    flushes each session's browser queue to its websocket.
    """

    _singleton = None

    @classmethod
    def get_current(cls):
        """Return the singleton instance.

        Raises
        ------
        RuntimeError
            If no Server has been constructed yet.
        """
        if cls._singleton is None:
            raise RuntimeError("Server has not been initialized yet")

        return cls._singleton

    def __init__(self, ioloop, script_path, command_line):
        """Create the server. It won't be started yet.

        Parameters
        ----------
        ioloop : tornado.ioloop.IOLoop
        script_path : str
        command_line : str

        Raises
        ------
        RuntimeError
            If a Server has already been created.

        """
        if Server._singleton is not None:
            raise RuntimeError(
                "Server already initialized. Use .get_current() instead")

        Server._singleton = self

        _set_tornado_log_levels()

        self._ioloop = ioloop
        self._script_path = script_path
        self._command_line = command_line

        # Mapping of WebSocket->SessionInfo.
        self._session_infos = {}

        # Set by stop(); polled by _loop_coroutine to decide when to exit.
        self._must_stop = threading.Event()
        self._state = None
        self._set_state(State.INITIAL)
        self._message_cache = ForwardMsgCache()
        # Read by get_debug(). Without this, get_debug() raised AttributeError
        # because nothing in this class ever assigned it.
        self._report = None

    def start(self, on_started):
        """Start the server.

        Parameters
        ----------
        on_started : callable
            A callback that will be called when the server's run-loop
            has started, and the server is ready to begin receiving clients.

        Raises
        ------
        RuntimeError
            If the server is not in its initial state.

        """
        if self._state != State.INITIAL:
            raise RuntimeError("Server has already been started")

        LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        port = config.get_option("server.port")

        LOGGER.debug("Server started on port %s", port)

        self._ioloop.spawn_callback(self._loop_coroutine, on_started)

    def get_debug(self):
        """Return debug info for the current report.

        Returns
        -------
        dict
            ``{"report": <report debug info>}`` when a report is set,
            otherwise an empty dict.
        """
        if self._report:
            return {"report": self._report.get_debug()}
        return {}

    def _create_app(self):
        """Create our tornado web app.

        Static files are served from the package's static dir unless the
        Node dev server is in use.

        Returns
        -------
        tornado.web.Application

        """
        routes = [
            (r"/stream", _BrowserWebSocketHandler, dict(server=self)),
            (
                r"/healthz",
                HealthHandler,
                dict(
                    health_check=lambda: self.is_ready_for_browser_connection),
            ),
            (r"/debugz", DebugHandler, dict(server=self)),
            (r"/metrics", MetricsHandler),
            (r"/message", MessageCacheHandler,
             dict(cache=self._message_cache)),
        ]

        if config.get_option("global.developmentMode") and config.get_option(
                "global.useNode"):
            LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = util.get_static_dir()
            LOGGER.debug("Serving static content from %s", static_path)

            routes.extend([
                (
                    r"/()$",
                    StaticFileHandler,
                    {
                        "path": "%s/index.html" % static_path
                    },
                ),
                (r"/(.*)", StaticFileHandler, {
                    "path": "%s/" % static_path
                }),
            ])

        return tornado.web.Application(routes, **TORNADO_SETTINGS)

    def _set_state(self, new_state):
        """Record a state transition (and log it at debug level)."""
        LOGGER.debug("Server state: %s -> %s", self._state, new_state)
        self._state = new_state

    @property
    def is_ready_for_browser_connection(self):
        return self._state not in (State.INITIAL, State.STOPPING,
                                   State.STOPPED)

    @property
    def browser_is_connected(self):
        return self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED

    @tornado.gen.coroutine
    def _loop_coroutine(self, on_started=None):
        """Run-loop that flushes each session's browser queue to its socket.

        Loops until ``stop()`` sets ``_must_stop`` or the state leaves the
        known running states, then shuts down every ReportSession and moves
        to ``State.STOPPED``. ``_on_stopped()`` runs in a ``finally`` so the
        ioloop is stopped even if the loop raises.

        Parameters
        ----------
        on_started : callable or None
            Called with this Server once the loop has entered its running
            state.
        """
        try:
            if self._state == State.INITIAL:
                self._set_state(State.WAITING_FOR_FIRST_BROWSER)
            elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
                pass
            else:
                raise RuntimeError(
                    "Bad server state at start: %s" % self._state)

            if on_started is not None:
                on_started(self)

            while not self._must_stop.is_set():
                if self._state == State.WAITING_FOR_FIRST_BROWSER:
                    pass

                elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:

                    # Shallow-clone our sessions into a list, so we can iterate
                    # over it and not worry about whether it's being changed
                    # outside this coroutine.
                    session_pairs = list(self._session_infos.items())

                    for ws, session_info in session_pairs:
                        # The preheated placeholder (and a missing ws) has no
                        # browser to write to.
                        if ws is PREHEATED_REPORT_SESSION or ws is None:
                            continue
                        msg_list = session_info.session.flush_browser_queue()
                        for msg in msg_list:
                            try:
                                self._send_message(ws, session_info, msg)
                            except tornado.websocket.WebSocketClosedError:
                                self._remove_browser_connection(ws)
                                # The socket is gone. Don't keep sending (and
                                # caching) the rest of this batch for a
                                # session that was just shut down.
                                break
                            yield
                        yield

                elif self._state == State.NO_BROWSERS_CONNECTED:
                    pass

                else:
                    # Break out of the thread loop if we encounter any other state.
                    break

                yield tornado.gen.sleep(0.01)

            # Shut down all ReportSessions
            for session_info in list(self._session_infos.values()):
                session_info.session.shutdown()

            self._set_state(State.STOPPED)

        finally:
            self._on_stopped()

    def _send_message(self, ws, session_info, msg):
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler
            The socket connected to the client
        session_info : SessionInfo
            The SessionInfo associated with websocket
        msg : ForwardMsg
            The message to send to the client

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                    msg, session_info.session, session_info.report_run_count):

                # This session has probably cached this message. Send
                # a reference instead.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(msg, session_info.session,
                                            session_info.report_run_count)

        # If this was a `report_finished` message, we increment the
        # report_run_count for this session, and update the cache
        if (msg.WhichOneof("type") == "report_finished"
                and msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY):
            LOGGER.debug(
                "Report finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.report_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.report_run_count)

        # Ship it off!
        ws.write_message(serialize_forward_msg(msg_to_send), binary=True)

    def stop(self):
        """Ask the run-loop to exit (it checks ``_must_stop`` each pass)."""
        self._set_state(State.STOPPING)
        self._must_stop.set()

    def _on_stopped(self):
        """Called when our runloop is exiting, to shut down the ioloop.
        This will end our process.

        (Tests can patch this method out, to prevent the test's ioloop
        from being shutdown.)
        """
        self._ioloop.stop()

    def add_preheated_report_session(self):
        """Register a fake browser with the server and run the script.

        This is used to start running the user's script even before the first
        browser connects.
        """
        session = self._add_browser_connection(PREHEATED_REPORT_SESSION)
        session.handle_rerun_script_request(is_preheat=True)

    def _add_browser_connection(self, ws):
        """Register a connected browser with the server

        If ``ws`` is already registered, its existing session is returned.
        Otherwise a preheated session is reused if one exists, or a new
        ReportSession is created.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler or PREHEATED_REPORT_SESSION
            The newly-connected websocket handler

        Returns
        -------
        ReportSession
            The ReportSession associated with this browser connection

        """
        if ws not in self._session_infos:

            if PREHEATED_REPORT_SESSION in self._session_infos:
                assert len(self._session_infos) == 1
                LOGGER.debug("Reusing preheated context for ws %s", ws)
                session = self._session_infos[PREHEATED_REPORT_SESSION].session
                del self._session_infos[PREHEATED_REPORT_SESSION]
            else:
                LOGGER.debug("Creating new context for ws %s", ws)
                session = ReportSession(
                    ioloop=self._ioloop,
                    script_path=self._script_path,
                    command_line=self._command_line,
                )

            self._session_infos[ws] = SessionInfo(session)

            if ws is not PREHEATED_REPORT_SESSION:
                self._set_state(State.ONE_OR_MORE_BROWSERS_CONNECTED)

        return self._session_infos[ws].session

    def _remove_browser_connection(self, ws):
        """Shut down and forget the session registered for ``ws``.

        Safe to call repeatedly for the same ``ws``; later calls skip the
        shutdown. When no sessions remain, the state becomes
        NO_BROWSERS_CONNECTED.
        """
        session_info = self._session_infos.pop(ws, None)
        if session_info is not None:
            session_info.session.shutdown()

        if not self._session_infos:
            self._set_state(State.NO_BROWSERS_CONNECTED)