Пример #1
0
    def get(self):
        """Serve a cached ForwardMsg looked up by the ``hash`` query argument.

        On a cache hit, responds 200 with the serialized message as
        ``application/octet-stream``. Responds 404 and finishes the request
        when the ``hash`` argument is absent, or when ``self._cache`` has no
        message for that hash.
        """
        msg_hash = self.get_argument("hash", None)
        if msg_hash is None:
            # Hash is missing! This is a malformed request. The status stays
            # 404 (rather than 400) because existing tests expect 404 here.
            LOGGER.error(
                "HTTP request for cached message is missing the hash attribute."
            )
            self.set_status(404)
            raise tornado.web.Finish()

        message = self._cache.get_message(msg_hash)
        if message is None:
            # Message not in our cache. The hash is passed as a logging
            # argument so formatting is deferred to the logging framework.
            LOGGER.error(
                "HTTP request for cached message could not be fulfilled. "
                "No such message: %s",
                msg_hash,
            )
            self.set_status(404)
            raise tornado.web.Finish()

        # Lazy %-args: the message string is only built if DEBUG is enabled.
        LOGGER.debug("MessageCache HIT [hash=%s]", msg_hash)
        msg_str = serialize_forward_msg(message)
        self.set_header("Content-Type", "application/octet-stream")
        self.write(msg_str)
        self.set_status(200)
Пример #2
0
    def _send_message(self, session_info: SessionInfo, msg: ForwardMsg) -> None:
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Cacheable messages are (re-)added to ``self._message_cache`` for this
        session whether or not a reference was sent. When ``msg`` reports
        that the report finished successfully, the session's
        ``report_run_count`` is incremented and expired cache entries for the
        session are removed. Nothing is written if ``session_info.ws`` is
        None.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with websocket
        msg : ForwardMsg
            The message to send to the client. Its ``metadata.cacheable``
            field is set here, and ``populate_hash_if_needed`` may fill in
            its hash.

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                msg, session_info.session, session_info.report_run_count
            ):
                # This session has probably cached this message. Send
                # a reference instead.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(
                msg, session_info.session, session_info.report_run_count
            )

        # If this was a `report_finished` message, we increment the
        # report_run_count for this session, and update the cache
        if (
            msg.WhichOneof("type") == "report_finished"
            and msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY
        ):
            LOGGER.debug(
                "Report finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.report_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.report_run_count
            )

        # Ship it off!
        if session_info.ws is not None:
            session_info.ws.write_message(
                serialize_forward_msg(msg_to_send), binary=True
            )
Пример #3
0
    def handle_save_request(self, ws):
        """Save serialized version of report deltas to the cloud.

        "Progress" ForwardMsgs will be sent to the client during the upload.
        These messages are sent "out of band" - that is, they don't get
        enqueued into the ReportQueue (because they're not part of the report).
        Instead, they're written directly to the report's WebSocket.

        On success, a ForwardMsg whose ``report_uploaded`` field holds the
        URL produced by ``self._save_final_report`` is written to ``ws``. On
        failure, the exception is logged and a ForwardMsg whose
        ``report_uploaded`` field holds ``"<ExceptionType>: <message>"`` is
        written instead.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler
            The report's websocket handler.

        """
        # NOTE(review): this body yields futures, so it only works when it is
        # driven as a Tornado coroutine. No @tornado.gen.coroutine decorator
        # is visible on this method here; confirm that it is applied.

        def write_forward_msg(msg):
            """Serialize ``msg`` and write it to ``ws`` as a binary frame."""
            return ws.write_message(serialize_forward_msg(msg), binary=True)

        @tornado.gen.coroutine
        def progress(percent):
            progress_msg = ForwardMsg()
            progress_msg.upload_report_progress = percent
            yield write_forward_msg(progress_msg)

        try:
            # Indicate that the save is starting.
            yield progress(0)

            url = yield self._save_final_report(progress)

            # Indicate that the save is done.
            done_msg = ForwardMsg()
            done_msg.report_uploaded = url
            yield write_forward_msg(done_msg)

        except Exception as e:
            # Log first. If reporting the error over the websocket also
            # fails (a closed socket makes write_message raise), the
            # original failure would otherwise never reach the logs.
            LOGGER.warning("Failed to save report:", exc_info=e)

            # Horrible hack to show something if something breaks: the error
            # text is sent in the field that normally carries the URL.
            err_msg = "%s: %s" % (
                type(e).__name__,
                str(e) or "No further details.",
            )
            error_msg = ForwardMsg()
            error_msg.report_uploaded = err_msg
            yield write_forward_msg(error_msg)
Пример #4
0
    def _loop_coroutine(self, on_started=None):
        """Main server loop: flush each session's browser queue to its websocket.

        Runs until ``self._must_stop`` is set, or until the server is in a
        state other than WAITING_FOR_FIRST_BROWSER,
        ONE_OR_MORE_BROWSERS_CONNECTED or NO_BROWSERS_CONNECTED. It then
        shuts down every ReportSession, sets the state to STOPPED and calls
        ``self._on_stopped()``.

        Parameters
        ----------
        on_started : callable or None
            If given, called once with this server after the initial state
            check and before the loop starts.

        Raises
        ------
        RuntimeError
            If the server is not in the INITIAL or
            ONE_OR_MORE_BROWSERS_CONNECTED state on entry.

        """
        # NOTE(review): this method yields, so it must be driven as a Tornado
        # coroutine; no decorator is visible here, so confirm where it is set.
        if self._state == State.INITIAL:
            self._set_state(State.WAITING_FOR_FIRST_BROWSER)
        elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
            pass
        else:
            raise RuntimeError("Bad server state at start: %s" % self._state)

        if on_started is not None:
            on_started(self)

        while not self._must_stop.is_set():
            if self._state == State.WAITING_FOR_FIRST_BROWSER:
                pass

            elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:

                # Shallow-clone our sessions into a list, so we can iterate
                # over it and not worry about whether it's being changed
                # outside this coroutine.
                ws_session_pairs = list(self._report_sessions.items())

                for ws, session in ws_session_pairs:
                    # Skip keys that are not real browser websockets (the
                    # preheated-session sentinel, or None).
                    if ws is PREHEATED_REPORT_SESSION or ws is None:
                        continue
                    msg_list = session.flush_browser_queue()
                    for msg in msg_list:
                        msg_str = serialize_forward_msg(msg)
                        try:
                            ws.write_message(msg_str, binary=True)
                        except tornado.websocket.WebSocketClosedError:
                            # The browser went away. Drop the connection and
                            # stop writing this batch to it: every further
                            # write would fail the same way and trigger a
                            # redundant removal.
                            self._remove_browser_connection(ws)
                            break
                        yield
                    yield

            elif self._state == State.NO_BROWSERS_CONNECTED:
                pass

            else:
                # Break out of the thread loop if we encounter any other state.
                break

            yield tornado.gen.sleep(0.01)

        # Shut down all ReportSessions. Copy the values first, in case
        # shutdown() mutates self._report_sessions.
        for session in list(self._report_sessions.values()):
            session.shutdown()

        self._set_state(State.STOPPED)

        self._on_stopped()
Пример #5
0
    def test_message_cache(self):
        """/message serves a cached message by hash and 404s on any miss."""
        # Create a new ForwardMsg and cache it
        msg = _create_dataframe_msg([1, 2, 3])
        msg_hash = populate_hash_if_needed(msg)
        self._cache.add_message(msg, MagicMock(), 0)

        # Cache hit
        response = self.fetch("/message?hash=%s" % msg_hash)
        self.assertEqual(200, response.code)
        self.assertEqual(serialize_forward_msg(msg), response.body)

        # Cache misses
        # No "hash" argument at all (an unrelated argument is ignored).
        self.assertEqual(404, self.fetch("/message").code)
        self.assertEqual(404, self.fetch("/message?id=non_existent").code)
        # A well-formed request for a hash that is not in the cache.
        self.assertEqual(404, self.fetch("/message?hash=non_existent").code)
Пример #6
0
    def test_should_limit_msg_size(self):
        """An oversized ForwardMsg is serialized with its payload replaced by an error."""
        # Build a ForwardMsg whose markdown body is 60MB of "X".
        oversized = _create_dataframe_msg([1, 2, 3])
        oversized.delta.new_element.markdown.body = "X" * 60 * 1000 * 1000

        # serialize_forward_msg modifies the proto it is given, so hand it a
        # copy and keep `oversized` intact for the comparisons below.
        scratch = ForwardMsg()
        scratch.CopyFrom(oversized)
        serialized = serialize_forward_msg(scratch)

        round_tripped = ForwardMsg()
        round_tripped.ParseFromString(serialized)

        # Metadata survives unchanged; the contents are swapped for an error.
        self.assertEqual(round_tripped.metadata, oversized.metadata)
        self.assertNotEqual(round_tripped, oversized)
        self.assertEqual(
            round_tripped.delta.new_element.exception.message,
            "Data of size 60.0MB exceeds write limit of 50.0MB",
        )
Пример #7
0
    def test_should_limit_msg_size(self):
        """Messages above server.maxMessageSize are replaced by an error element."""
        max_message_size_mb = 50
        # Set max message size to defined value
        from streamlit.server import server_util

        # Undo the global changes below when the test ends, so the reduced
        # limit does not leak into other tests. Cleanups run LIFO: the option
        # is restored first, then the cached limit is cleared so the next
        # caller recomputes it from the restored option.
        # NOTE(review): the restored value is recorded as defined by "test";
        # confirm that nothing downstream depends on its original source.
        self.addCleanup(setattr, server_util, "_max_message_size_bytes", None)
        self.addCleanup(
            config._set_option,
            "server.maxMessageSize",
            config.get_option("server.maxMessageSize"),
            "test",
        )

        server_util._max_message_size_bytes = None  # Reset cached value
        config._set_option("server.maxMessageSize", max_message_size_mb, "test")

        # Set up a larger than limit ForwardMsg string
        large_msg = _create_dataframe_msg([1, 2, 3])
        large_msg.delta.new_element.markdown.body = (
            "X" * (max_message_size_mb + 10) * 1000 * 1000
        )
        # Create a copy, since serialize_forward_msg modifies the original proto
        large_msg_copy = ForwardMsg()
        large_msg_copy.CopyFrom(large_msg)
        deserialized_msg = ForwardMsg()
        deserialized_msg.ParseFromString(serialize_forward_msg(large_msg_copy))

        # The metadata should be the same, but contents should be replaced
        self.assertEqual(deserialized_msg.metadata, large_msg.metadata)
        self.assertNotEqual(deserialized_msg, large_msg)
        # assertIn shows both strings on failure, unlike assertTrue(x in y).
        self.assertIn(
            "exceeds the message size limit",
            deserialized_msg.delta.new_element.exception.message,
        )
Пример #8
0
 def progress(percent):
     """Write an upload-progress ForwardMsg to ``ws``.

     ``ws`` is not a parameter; it is a free name that must be bound in an
     enclosing scope. Yields whatever ``ws.write_message`` returns.
     NOTE(review): no @tornado.gen.coroutine decorator is visible here. As a
     plain generator function, calling it does nothing until it is iterated.
     """
     update = ForwardMsg()
     update.upload_report_progress = percent
     payload = serialize_forward_msg(update)
     yield ws.write_message(payload, binary=True)