Пример #1
0
    def test_cache_stats_provider(self):
        """ForwardMsgCache reports one CacheStat per cached message."""
        msg_cache = ForwardMsgCache()
        mock_session = _create_mock_session()

        # A freshly created cache has nothing to report.
        self.assertEqual([], msg_cache.get_stats())

        # Hash and cache two distinct dataframe messages, in order.
        cached_msgs = []
        for values in ([1, 2, 3], [5, 4, 3, 2, 1, 0]):
            df_msg = _create_dataframe_msg(values)
            populate_hash_if_needed(df_msg)
            msg_cache.add_message(df_msg, mock_session, 0)
            cached_msgs.append(df_msg)

        # Each cached message should contribute a stat sized to its payload.
        expected_stats = {
            CacheStat(
                category_name="ForwardMessageCache",
                cache_name="",
                byte_length=cached.ByteSize(),
            )
            for cached in cached_msgs
        }
        self.assertEqual(expected_stats, set(msg_cache.get_stats()))
Пример #2
0
def serialize_forward_msg(msg) -> bytes:
    """Serialize a ForwardMsg to send to a client.

    If the serialized message is larger than MESSAGE_SIZE_LIMIT, the
    message's delta is overwritten in place with an exception element
    describing the overflow, and that replacement is serialized instead.

    Parameters
    ----------
    msg : ForwardMsg
        The message to serialize. ``populate_hash_if_needed`` is called on it
        first, and it is mutated in place if it exceeds the size limit.

    Returns
    -------
    bytes
        The serialized byte string to send.

    """
    populate_hash_if_needed(msg)
    msg_str = msg.SerializeToString()

    if len(msg_str) > MESSAGE_SIZE_LIMIT:
        # Imported lazily, only on the oversize path (presumably to avoid an
        # import cycle — TODO confirm).
        import streamlit.elements.exception_proto as exception_proto

        error = RuntimeError(
            f"Data of size {len(msg_str)/1e6:.1f}MB exceeds write limit of {MESSAGE_SIZE_LIMIT/1e6}MB"
        )
        # Overwrite the offending ForwardMsg.delta with an error to display.
        # This assumes that the size limit wasn't exceeded due to metadata.
        exception_proto.marshall(msg.delta.new_element.exception, error)
        msg_str = msg.SerializeToString()

    return msg_str
Пример #3
0
    def _send_message(self, session_info: SessionInfo, msg: ForwardMsg) -> None:
        """Send a message to a client.

        If the message is cacheable and the client is likely to have already
        cached it, a "reference" message that contains only the hash of the
        message may be sent instead of the full payload.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with the websocket. Its
            ``report_run_count`` is incremented when a ``report_finished``
            message with FINISHED_SUCCESSFULLY passes through.
        msg : ForwardMsg
            The message to send to the client. Its ``metadata.cacheable``
            flag is set in place, and for cacheable messages
            ``populate_hash_if_needed`` is called on it.

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                msg, session_info.session, session_info.report_run_count
            ):
                # This session has probably cached this message. Send
                # a reference instead.
                # Pass args to the logger rather than %-formatting eagerly,
                # so no string is built when debug logging is disabled.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(
                msg, session_info.session, session_info.report_run_count
            )

        # If this was a successfully finished `report_finished` message,
        # increment the report_run_count for this session and purge the
        # session's expired cache entries.
        if (
            msg.WhichOneof("type") == "report_finished"
            and msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY
        ):
            LOGGER.debug(
                "Report finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.report_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.report_run_count
            )

        # Ship it off! (Skipped when the session has no websocket attached.)
        if session_info.ws is not None:
            session_info.ws.write_message(
                serialize_forward_msg(msg_to_send), binary=True
            )
Пример #4
0
    def test_msg_hash(self):
        """Identical ForwardMsg payloads hash equally; differing ones don't."""
        msg_a = _create_dataframe_msg([1, 2, 3])
        msg_b = _create_dataframe_msg([1, 2, 3])
        hash_a = populate_hash_if_needed(msg_a)
        hash_b = populate_hash_if_needed(msg_b)
        self.assertEqual(hash_a, hash_b)

        msg_c = _create_dataframe_msg([2, 3, 4])
        self.assertNotEqual(
            populate_hash_if_needed(msg_a), populate_hash_if_needed(msg_c)
        )
Пример #5
0
    def test_get_message(self):
        """A message added to ForwardMsgCache can be fetched back by its hash."""
        msg_cache = ForwardMsgCache()
        mock_session = _create_mock_session()
        df_msg = _create_dataframe_msg([1, 2, 3])

        expected_hash = populate_hash_if_needed(df_msg)

        msg_cache.add_message(df_msg, mock_session, 0)
        fetched = msg_cache.get_message(expected_hash)
        self.assertEqual(df_msg, fetched)
Пример #6
0
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
    """Return the wire bytes for ``msg``, replacing oversized payloads.

    ``populate_hash_if_needed`` is called on the message before serializing.
    When the result is larger than ``get_max_message_size_bytes()``, the
    message's delta is overwritten in place with a MessageSizeError exception
    element and the message is serialized again.
    """
    populate_hash_if_needed(msg)
    serialized = msg.SerializeToString()

    # Common case: the message fits, send it as-is.
    if len(serialized) <= get_max_message_size_bytes():
        return serialized

    import streamlit.elements.exception as exception

    # Replace the delta with an error for the client to display. This
    # assumes the limit was not exceeded because of metadata alone.
    exception.marshall(msg.delta.new_element.exception,
                       MessageSizeError(serialized))
    return msg.SerializeToString()
Пример #7
0
    def test_message_cache(self):
        """The /message endpoint serves cached ForwardMsgs by hash."""
        # Put a freshly hashed ForwardMsg into the server's cache.
        df_msg = _create_dataframe_msg([1, 2, 3])
        df_hash = populate_hash_if_needed(df_msg)
        self._cache.add_message(df_msg, MagicMock(), 0)

        # A known hash returns the serialized message.
        hit = self.fetch(f"/message?hash={df_hash}")
        self.assertEqual(200, hit.code)
        self.assertEqual(serialize_forward_msg(df_msg), hit.body)

        # A missing hash parameter, or an unrelated query, is a 404.
        for url in ("/message", "/message?id=non_existent"):
            self.assertEqual(404, self.fetch(url).code)
Пример #8
0
    def test_forwardmsg_hashing(self):
        """_send_message must attach a hash to outgoing ForwardMsgs."""
        with self._patch_report_session():
            yield self.start_server_loop()
            client = yield self.ws_connect()

            # Take the first (presumably only) session the server knows of;
            # it belongs to the client connected above.
            known_sessions = list(self.server._session_info_by_id.values())
            session_info = known_sessions[0]

            # Build a message with an empty hash field, so the final
            # assertion proves _send_message filled it in.
            outgoing = _create_dataframe_msg([1, 2, 3])
            outgoing.ClearField("hash")
            self.server._send_message(session_info, outgoing)

            incoming = yield self.read_forward_msg(client)
            self.assertEqual(populate_hash_if_needed(outgoing), incoming.hash)
Пример #9
0
    def test_message_expiration(self):
        """Check how ForwardMsgCache expires per-session message references."""
        config._set_option("global.maxCachedMessageAge", 1, "test")

        msg_cache = ForwardMsgCache()
        first_session = _create_mock_session()
        first_runs = 0

        df_msg = _create_dataframe_msg([1, 2, 3])
        df_hash = populate_hash_if_needed(df_msg)

        msg_cache.add_message(df_msg, first_session, first_runs)

        # One run later the reference has not yet expired.
        first_runs += 1
        self.assertTrue(
            msg_cache.has_message_reference(df_msg, first_session, first_runs)
        )

        # After a second run the reference is expired for first_session,
        # but the message itself has not been removed from the cache yet.
        first_runs += 1
        self.assertFalse(
            msg_cache.has_message_reference(df_msg, first_session, first_runs)
        )
        self.assertIsNotNone(msg_cache.get_message(df_hash))

        # A second session now references the same message.
        second_session = _create_mock_session()
        second_runs = 0
        msg_cache.add_message(df_msg, second_session, second_runs)

        # Purging first_session's expired entries must keep the message
        # cached, because second_session still references it.
        msg_cache.remove_expired_session_entries(first_session, first_runs)
        self.assertFalse(
            msg_cache.has_message_reference(df_msg, first_session, first_runs)
        )
        self.assertTrue(
            msg_cache.has_message_reference(df_msg, second_session, second_runs)
        )

        # Once second_session's reference expires too, the message should be
        # evicted from the cache entirely.
        second_runs += 2
        msg_cache.remove_expired_session_entries(second_session, second_runs)
        self.assertIsNone(msg_cache.get_message(df_hash))
Пример #10
0
 def test_reference_msg(self):
     """A reference ForwardMsg carries its source's hash and metadata."""
     source = _create_dataframe_msg([1, 2, 3], 34)
     reference = create_reference_msg(source)
     source_hash = populate_hash_if_needed(source)
     self.assertEqual(source_hash, reference.ref_hash)
     self.assertEqual(source.metadata, reference.metadata)
Пример #11
0
 def test_delta_metadata(self):
     """Messages that differ only in delta metadata share the same hash."""
     first = _create_dataframe_msg([1, 2, 3], 1)
     second = _create_dataframe_msg([1, 2, 3], 2)
     first_hash = populate_hash_if_needed(first)
     second_hash = populate_hash_if_needed(second)
     self.assertEqual(first_hash, second_hash)