def _enqueue_script_finished_message(self, status):
        """Build a script_finished ForwardMsg and hand it to ``self.enqueue``.

        Parameters
        ----------
        status : ScriptFinishedStatus
            Value stored in the message's ``script_finished`` field.

        """
        finished = ForwardMsg()
        finished.script_finished = status
        self.enqueue(finished)
    def enqueue(self, msg: ForwardMsg) -> None:
        """Queue ``msg`` after enforcing the set_page_config ordering rule.

        Once a delta or a page_config_changed message has been queued,
        further page_config_changed messages are rejected.

        Raises
        ------
        StreamlitAPIException
            If ``msg`` carries ``page_config_changed`` while page config
            changes are no longer allowed.
        """
        changes_page_config = msg.HasField("page_config_changed")

        if changes_page_config and not self._set_page_config_allowed:
            raise StreamlitAPIException(
                "`set_page_config()` can only be called once per app, "
                "and must be called as the first Streamlit command in your script.\n\n"
                "For more information refer to the [docs]"
                "(https://docs.streamlit.io/en/stable/api.html#streamlit.set_page_config)."
            )

        # Any st command (delta) or a page config change closes the window
        # in which set_page_config may be called.
        if changes_page_config or msg.HasField("delta"):
            self._set_page_config_allowed = False

        self._enqueue(msg)
예제 #3
0
    def handle_save_request(self, ws):
        """Save a serialized version of the report deltas to the cloud.

        Progress and the outcome are streamed to the client over ``ws`` as
        binary ForwardMsgs. ``upload_report_progress`` is sent while saving,
        then ``report_uploaded`` is sent with either the resulting URL or, on
        failure, an error description. On failure the exception is re-raised
        after the error message has been sent.
        """
        def _encoded(field, value):
            # Serialize a fresh ForwardMsg that has exactly one field set.
            out = ForwardMsg()
            setattr(out, field, value)
            return out.SerializeToString()

        @tornado.gen.coroutine
        def progress(percent):
            yield ws.write_message(
                _encoded('upload_report_progress', percent), binary=True)

        try:
            # Tell the client the save is starting.
            yield progress(0)
            url = yield self._save_final_report(progress)
            # Tell the client the save is done.
            yield ws.write_message(_encoded('report_uploaded', url),
                                   binary=True)
        except Exception as e:
            # HACK: reuse the report_uploaded field so the client shows
            # *something* when saving breaks.
            details = str(e) or 'No further details.'
            yield ws.write_message(
                _encoded('report_uploaded',
                         '%s: %s' % (type(e).__name__, details)),
                binary=True)
            raise e
예제 #4
0
    def test_set_page_config_first(self, _1):
        """A page_config message queued after a delta must be rejected."""
        session = ReportSession(
            None, "", "", MagicMock(spec=UploadedFileManager))

        markdown_msg = ForwardMsg()
        markdown_msg.delta.new_element.markdown.body = "foo"
        session.enqueue(markdown_msg)

        page_config_msg = ForwardMsg()
        page_config_msg.page_config_changed.title = "foo"
        with self.assertRaises(StreamlitAPIException):
            session.enqueue(page_config_msg)
예제 #5
0
    def test_set_page_config_first(self):
        """set_page_config is rejected once another st command was queued."""
        ctx = ReportContext(
            "TestSessionID", lambda msg: None, "", None, None, None)

        markdown_msg = ForwardMsg()
        markdown_msg.delta.new_element.markdown.body = "foo"
        ctx.enqueue(markdown_msg)

        page_config_msg = ForwardMsg()
        page_config_msg.page_config_changed.title = "foo"
        with self.assertRaises(StreamlitAPIException):
            ctx.enqueue(page_config_msg)
예제 #6
0
    def enqueue(self, msg: ForwardMsg) -> None:
        """Add ``msg`` to the queue, merging it into an earlier Delta if possible.

        Non-composable messages are appended unchanged. A composable Delta
        whose ``metadata.delta_path`` matches a queued message is combined
        with it via ``_maybe_compose_deltas``, and the queued message is
        replaced in place. This keeps redundant Deltas from reaching the
        frontend. If composition yields nothing, the message is appended and
        its slot is recorded for future merges.
        """
        with self._lock:
            if not _is_composable_message(msg):
                self._queue.append(msg)
                return

            path_key = tuple(msg.metadata.delta_path)
            if path_key in self._delta_index_map:
                slot = self._delta_index_map[path_key]
                merged = _maybe_compose_deltas(self._queue[slot].delta,
                                               msg.delta)
                if merged is not None:
                    replacement = ForwardMsg()
                    replacement.delta.CopyFrom(merged)
                    replacement.metadata.CopyFrom(msg.metadata)
                    self._queue[slot] = replacement
                    return

            # No composition occurred: append, and remember the slot so a
            # later Delta for the same path can be merged into it.
            self._delta_index_map[path_key] = len(self._queue)
            self._queue.append(msg)
예제 #7
0
 def _create_session_state_changed_message(self) -> ForwardMsg:
     """Return a new session_state_changed ForwardMsg for the current state.

     ``run_on_save`` mirrors the session's setting. ``script_is_running`` is
     True exactly when the state is ``AppSessionState.APP_IS_RUNNING``.
     """
     msg = ForwardMsg()
     changed = msg.session_state_changed
     changed.run_on_save = self._run_on_save
     changed.script_is_running = self._state == AppSessionState.APP_IS_RUNNING
     return msg
예제 #8
0
 def _enqueue_new_report_message(self):
     """Generate a new report id and enqueue a matching new_report ForwardMsg.

     The message carries the report's (new) id, name and script path.
     """
     self._report.generate_new_id()

     msg = ForwardMsg()
     report_msg = msg.new_report
     report_msg.id = self._report.report_id
     report_msg.name = self._report.name
     report_msg.script_path = self._report.script_path
     self.enqueue(msg)
예제 #9
0
 def _enqueue_session_state_changed_message(self):
     """Enqueue a session_state_changed ForwardMsg for the current state.

     ``report_is_running`` is True exactly when the state is
     ``ReportSessionState.REPORT_IS_RUNNING``.
     """
     msg = ForwardMsg()
     changed = msg.session_state_changed
     changed.run_on_save = self._run_on_save
     changed.report_is_running = (
         self._state == ReportSessionState.REPORT_IS_RUNNING)
     self.enqueue(msg)
예제 #10
0
    def handle_git_information_request(self) -> None:
        """Enqueue a git_info_changed ForwardMsg describing the script's repo.

        Nothing is sent when ``get_repo_info()`` returns None. Any exception
        raised along the way is logged at debug level and otherwise ignored.
        """
        msg = ForwardMsg()

        try:
            from streamlit.git_util import GitRepo

            repo = GitRepo(self._session_data.script_path)

            repo_info = repo.get_repo_info()
            if repo_info is None:
                return

            repository_name, branch, module = repo_info

            info = msg.git_info_changed
            info.repository = repository_name
            info.branch = branch
            info.module = module
            info.untracked_files[:] = repo.untracked_files
            info.uncommitted_files[:] = repo.uncommitted_files

            # Precedence: detached HEAD, then ahead of remote, then default.
            if repo.is_head_detached:
                info.state = GitInfo.GitStates.HEAD_DETACHED
            elif len(repo.ahead_commits) > 0:
                info.state = GitInfo.GitStates.AHEAD_OF_REMOTE
            else:
                info.state = GitInfo.GitStates.DEFAULT

            self.enqueue(msg)
        except Exception as e:
            # Git may not even be installed, so no action is needed here;
            # the debug log is kept because it helps when troubleshooting.
            LOGGER.debug("Obtaining Git information produced an error",
                         exc_info=e)
예제 #11
0
 def _enqueue_session_state_changed_message(self) -> None:
     """Enqueue a session_state_changed ForwardMsg for the current state.

     ``script_is_running`` is True exactly when the state is
     ``AppSessionState.APP_IS_RUNNING``.
     """
     msg = ForwardMsg()
     changed = msg.session_state_changed
     changed.run_on_save = self._run_on_save
     changed.script_is_running = (
         self._state == AppSessionState.APP_IS_RUNNING)
     self.enqueue(msg)
예제 #12
0
    def _maybe_enqueue_initialize_message(self):
        """Enqueue the one-time ``initialize`` ForwardMsg for this session.

        The ``_sent_initialize_message`` flag is set before the message is
        built, so later calls return immediately and the message is sent at
        most once.

        The message carries config flags (sharing mode, usage-stats
        gathering), environment info (Streamlit and Python versions), the
        session's run state, and user info (installation id and email).
        """
        if self._sent_initialize_message:
            return

        self._sent_initialize_message = True

        msg = ForwardMsg()
        imsg = msg.initialize

        imsg.config.sharing_enabled = (config.get_option('global.sharingMode')
                                       != 'off')
        LOGGER.debug('New browser connection: sharing_enabled=%s',
                     imsg.config.sharing_enabled)

        imsg.config.gather_usage_stats = (
            config.get_option('browser.gatherUsageStats'))
        LOGGER.debug('New browser connection: gather_usage_stats=%s',
                     imsg.config.gather_usage_stats)

        imsg.environment_info.streamlit_version = __version__
        # NOTE: sys.version_info has five fields, so this yields e.g.
        # "3.8.5.final.0" rather than just "3.8.5".
        imsg.environment_info.python_version = ('.'.join(
            map(str, sys.version_info)))

        imsg.session_state.run_on_save = self._run_on_save
        imsg.session_state.report_is_running = (
            self._state == ReportSessionState.REPORT_IS_RUNNING)

        imsg.user_info.installation_id = __installation_id__
        # The current credentials may have no activation. Send an empty email
        # in that case instead of raising AttributeError on None.
        activation = Credentials.get_current().activation
        imsg.user_info.email = activation.email if activation else ""

        self.enqueue(msg)
예제 #13
0
    def enqueue(self, msg):
        """Add ``msg`` to the queue, merging Delta messages that share a key.

        Non-delta messages are appended unchanged. A Delta is keyed by its
        parent block's container, the parent block's path, and its
        ``delta_id``. If a Delta with the same key is already queued, the two
        are combined with ``compose_deltas``, and the queued message is
        replaced in place by one that carries the newer message's metadata.

        Parameters
        ----------
        msg : ForwardMsg
        """
        with self._lock:
            if not msg.HasField("delta"):
                self._queue.append(msg)
                return

            meta = msg.metadata
            container = meta.parent_block.container
            block_path = tuple(meta.parent_block.path)
            key = ((container, block_path), meta.delta_id)

            if key not in self._delta_index_map:
                # First Delta for this key: append and remember its slot.
                self._delta_index_map[key] = len(self._queue)
                self._queue.append(msg)
                return

            slot = self._delta_index_map[key]
            merged_delta = compose_deltas(self._queue[slot].delta, msg.delta)
            combined = ForwardMsg()
            combined.delta.CopyFrom(merged_delta)
            combined.metadata.CopyFrom(msg.metadata)
            self._queue[slot] = combined
예제 #14
0
    def _enqueue_new_session_message(self) -> None:
        """Enqueue a new_session ForwardMsg describing the upcoming script run.

        The message bundles a fresh script-run id, the session's name and
        script path, the config and custom theme, and an ``initialize``
        section with user, environment and session-state information, plus
        the command line and session id.
        """
        msg = ForwardMsg()
        session_msg = msg.new_session

        session_msg.script_run_id = _generate_scriptrun_id()
        session_msg.name = self._session_data.name
        session_msg.script_path = self._session_data.script_path

        _populate_config_msg(session_msg.config)
        _populate_theme_msg(session_msg.custom_theme)

        # Immutable session data. It is sent with every new session so we
        # don't have to track whether the client already received it. It does
        # not change from run to run; it's up to the client to perform
        # one-time initialization only once.
        init_msg = session_msg.initialize
        _populate_user_info_msg(init_msg.user_info)

        env = init_msg.environment_info
        env.streamlit_version = __version__
        env.python_version = ".".join(map(str, sys.version_info))

        state = init_msg.session_state
        state.run_on_save = self._run_on_save
        state.script_is_running = (
            self._state == AppSessionState.APP_IS_RUNNING)

        init_msg.command_line = self._session_data.command_line
        init_msg.session_id = self.id

        self.enqueue(msg)
예제 #15
0
    def handle_git_information_request(self):
        """Enqueue a git_info_changed ForwardMsg for the report's repository.

        This is best-effort. If ``get_repo_info()`` returns None, nothing is
        sent. Ordinary errors (``Exception`` subclasses) are logged at debug
        level instead of being propagated. Unlike the previous bare
        ``except:``, KeyboardInterrupt and SystemExit are no longer swallowed.
        """
        msg = ForwardMsg()

        try:
            from streamlit.git_util import GitRepo

            self._repo = GitRepo(self._report.script_path)

            repo_info = self._repo.get_repo_info()
            if repo_info is None:
                # No usable repository information; nothing to report.
                return

            repo, branch, module = repo_info

            msg.git_info_changed.repository = repo
            msg.git_info_changed.branch = branch
            msg.git_info_changed.module = module

            msg.git_info_changed.untracked_files[:] = self._repo.untracked_files
            msg.git_info_changed.uncommitted_files[:] = self._repo.uncommitted_files

            if self._repo.is_head_detached:
                msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED
            elif len(self._repo.ahead_commits) > 0:
                msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE
            else:
                msg.git_info_changed.state = GitInfo.GitStates.DEFAULT

            self.enqueue(msg)
        except Exception as e:
            # Users may never even install Git, so this error requires no
            # action. It is logged because it can be useful for debugging.
            LOGGER.debug("Obtaining Git information produced an error",
                         exc_info=e)
예제 #16
0
    def _send_message(self, session_info: SessionInfo, msg: ForwardMsg) -> None:
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with websocket
        msg : ForwardMsg
            The message to send to the client

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                msg, session_info.session, session_info.report_run_count
            ):

                # This session has probably cached this message. Send
                # a reference instead.
                # Lazy %-args: only formatted if debug logging is enabled.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(
                msg, session_info.session, session_info.report_run_count
            )

        # If this was a successful `report_finished` message, we increment
        # the report_run_count for this session, and drop cache entries that
        # have now expired.
        if (
            msg.WhichOneof("type") == "report_finished"
            and msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY
        ):
            LOGGER.debug(
                "Report finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.report_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.report_run_count
            )

        # Ship it off! (Skipped when the session has no websocket.)
        if session_info.ws is not None:
            session_info.ws.write_message(
                serialize_forward_msg(msg_to_send), binary=True
            )
예제 #17
0
    def _enqueue_new_report_message(self):
        """Start a new report run and enqueue its new_report ForwardMsg.

        A fresh report id is generated first. Besides the report's id, name
        and script path, the message carries optional git deploy params and
        an ``initialize`` section. That section holds config, environment
        info, session state, user info, the command line and the session id.
        """
        self._report.generate_new_id()

        msg = ForwardMsg()
        report_msg = msg.new_report
        report_msg.report_id = self._report.report_id
        report_msg.name = self._report.name
        report_msg.script_path = self._report.script_path

        # Optional git deploy params: (repository, branch, module).
        deploy_params = self.get_deploy_params()
        if deploy_params is not None:
            repository, branch_name, module_name = deploy_params
            deploy_msg = report_msg.deploy_params
            deploy_msg.repository = repository
            deploy_msg.branch = branch_name
            deploy_msg.module = module_name

        # Immutable session data. It is sent with every new report so we
        # don't have to track whether the client already received it. It
        # does not change from run to run; it's up to the client to perform
        # one-time initialization only once.
        init_msg = report_msg.initialize

        cfg = init_msg.config
        cfg.sharing_enabled = config.get_option("global.sharingMode") != "off"
        cfg.gather_usage_stats = config.get_option("browser.gatherUsageStats")
        cfg.max_cached_message_age = config.get_option(
            "global.maxCachedMessageAge")
        cfg.mapbox_token = config.get_option("mapbox.token")
        cfg.allow_run_on_save = config.get_option("server.allowRunOnSave")

        env = init_msg.environment_info
        env.streamlit_version = __version__
        env.python_version = ".".join(map(str, sys.version_info))

        state = init_msg.session_state
        state.run_on_save = self._run_on_save
        state.report_is_running = (
            self._state == ReportSessionState.REPORT_IS_RUNNING)

        user = init_msg.user_info
        user.installation_id = Installation.instance().installation_id
        user.installation_id_v1 = Installation.instance().installation_id_v1
        user.installation_id_v2 = Installation.instance().installation_id_v2
        user.installation_id_v3 = Installation.instance().installation_id_v3
        # An empty email is sent when there is no activation.
        user.email = (Credentials.get_current().activation.email
                      if Credentials.get_current().activation else "")

        init_msg.command_line = self._report.command_line
        init_msg.session_id = self.id

        self.enqueue(msg)
예제 #18
0
    def enqueue(self, msg: ForwardMsg) -> None:
        """Queue ``msg`` after enforcing the set_page_config ordering rules.

        Raises
        ------
        StreamlitAPIException
            If ``msg`` changes the page config after that is no longer
            allowed.
        """
        changes_page_config = msg.HasField("page_config_changed")

        if changes_page_config and not self._set_page_config_allowed:
            raise StreamlitAPIException(
                "`set_page_config()` can only be called once per app, "
                "and must be called as the first Streamlit command in your script.\n\n"
                "For more information refer to the [docs]"
                "(https://docs.streamlit.io/library/api-reference/utilities/st.set_page_config)."
            )

        # set_page_config becomes disallowed once either:
        # - this message itself changes the page config, or
        # - the script has started and some other st call (a delta) occurs.
        if changes_page_config or (msg.HasField("delta")
                                   and self._has_script_started):
            self._set_page_config_allowed = False

        self._enqueue(msg)
    def _maybe_enqueue_initialize_message(self):
        """Enqueue the session's ``initialize`` ForwardMsg, at most once.

        The ``_sent_initialize_message`` flag is flipped before the message
        is built, so later calls return immediately. The message carries
        config options, environment info, session state, user info, the
        command line and the session id.
        """
        if self._sent_initialize_message:
            return
        self._sent_initialize_message = True

        msg = ForwardMsg()
        init_msg = msg.initialize

        cfg = init_msg.config
        cfg.sharing_enabled = config.get_option("global.sharingMode") != "off"
        cfg.gather_usage_stats = config.get_option("browser.gatherUsageStats")
        cfg.max_cached_message_age = config.get_option(
            "global.maxCachedMessageAge")
        cfg.mapbox_token = config.get_option("mapbox.token")
        cfg.allow_run_on_save = config.get_option("server.allowRunOnSave")

        LOGGER.debug(
            "New browser connection: "
            "gather_usage_stats=%s, "
            "sharing_enabled=%s, "
            "max_cached_message_age=%s",
            cfg.gather_usage_stats,
            cfg.sharing_enabled,
            cfg.max_cached_message_age,
        )

        env = init_msg.environment_info
        env.streamlit_version = __version__
        env.python_version = ".".join(map(str, sys.version_info))

        state = init_msg.session_state
        state.run_on_save = self._run_on_save
        state.report_is_running = (
            self._state == ReportSessionState.REPORT_IS_RUNNING)

        user = init_msg.user_info
        user.installation_id = Installation.instance().installation_id
        user.installation_id_v1 = Installation.instance().installation_id_v1
        user.installation_id_v2 = Installation.instance().installation_id_v2
        # An empty email is sent when there is no activation.
        user.email = (Credentials.get_current().activation.email
                      if Credentials.get_current().activation else "")

        init_msg.command_line = self._report.command_line
        init_msg.session_id = self.id

        self.enqueue(msg)
예제 #20
0
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
    """Return the wire bytes for ``msg``, replacing oversized payloads.

    The message hash is populated first, if needed. If the serialized form
    is larger than ``get_max_message_size_bytes()``, ``msg.delta`` is
    overwritten in place with an exception element that carries a
    ``MessageSizeError``, and the message is serialized again.
    """
    populate_hash_if_needed(msg)
    serialized = msg.SerializeToString()

    if len(serialized) <= get_max_message_size_bytes():
        return serialized

    import streamlit.elements.exception as exception

    # Overwrite the offending ForwardMsg.delta with an error to display.
    # This assumes that the size limit wasn't exceeded due to metadata.
    exception.marshall(msg.delta.new_element.exception,
                       MessageSizeError(serialized))
    return msg.SerializeToString()
예제 #21
0
    def test_set_page_config_immutable(self, _1):
        """Enqueueing a page_config message twice fails on the second call."""
        session = ReportSession(
            None, "", "", MagicMock(spec=UploadedFileManager))

        page_config_msg = ForwardMsg()
        page_config_msg.page_config_changed.title = "foo"

        session.enqueue(page_config_msg)
        with self.assertRaises(StreamlitAPIException):
            session.enqueue(page_config_msg)
예제 #22
0
def _is_composable_message(msg: ForwardMsg) -> bool:
    """True if the ForwardMsg is potentially composable with other ForwardMsgs."""
    if not msg.HasField("delta"):
        # Non-delta messages are never composable.
        return False

    # We never compose add_rows messages in Python, because the add_rows
    # operation can raise errors, and we don't have a good way of handling
    # those errors in the message queue.
    delta_type = msg.delta.WhichOneof("type")
    return delta_type != "add_rows" and delta_type != "arrow_add_rows"
예제 #23
0
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
    """Return the wire bytes for ``msg``, replacing oversized payloads.

    The message hash is populated first, if needed. If the serialized form
    exceeds ``MESSAGE_SIZE_LIMIT``, ``msg.delta`` is overwritten in place
    with an exception element that describes the size overflow, and the
    message is serialized again.
    """
    populate_hash_if_needed(msg)
    serialized = msg.SerializeToString()
    size = len(serialized)

    if size <= MESSAGE_SIZE_LIMIT:
        return serialized

    import streamlit.elements.exception as exception

    error = RuntimeError(
        f"Data of size {size/1e6:.1f}MB exceeds write limit of {MESSAGE_SIZE_LIMIT/1e6}MB"
    )
    # Overwrite the offending ForwardMsg.delta with an error to display.
    # This assumes that the size limit wasn't exceeded due to metadata.
    exception.marshall(msg.delta.new_element.exception, error)
    return msg.SerializeToString()
예제 #24
0
    def test_set_page_config_immutable(self):
        """Enqueueing a page_config message twice fails on the second call."""
        ctx = ReportContext(
            "TestSessionID", lambda msg: None, "", None, None, None)

        page_config_msg = ForwardMsg()
        page_config_msg.page_config_changed.title = "foo"

        ctx.enqueue(page_config_msg)
        with self.assertRaises(StreamlitAPIException):
            ctx.enqueue(page_config_msg)
예제 #25
0
    def test_can_specify_all_options(self, patched_config):
        """_populate_theme_msg fills custom_theme when all options are given."""
        # With no overrides the mock specifies every theme option.
        patched_config.get_options_for_section.side_effect = (
            _mock_get_options_for_section())

        msg = ForwardMsg()
        report_msg = msg.new_report
        report_session._populate_theme_msg(report_msg.custom_theme)

        self.assertTrue(report_msg.HasField("custom_theme"))
        theme = report_msg.custom_theme
        self.assertEqual(theme.primary_color, "coral")
        self.assertEqual(theme.background_color, "white")
예제 #26
0
    def test_disallow_set_page_config_twice(self):
        """A second set_page_config in the same script run is rejected."""
        ctx = ScriptRunContext(
            "TestSessionID",
            lambda msg: None,
            "",
            SessionState(),
            UploadedFileManager(),
        )
        ctx.on_script_start()

        first = ForwardMsg()
        first.page_config_changed.title = "foo"
        ctx.enqueue(first)

        second = ForwardMsg()
        second.page_config_changed.title = "bar"
        with self.assertRaises(StreamlitAPIException):
            ctx.enqueue(second)
예제 #27
0
    def test_logs_warning_if_base_invalid(self, patched_config,
                                          patched_logger):
        """An unknown theme.base value logs a warning naming the fallback."""
        patched_config.get_options_for_section.side_effect = (
            _mock_get_options_for_section({"base": "blah"}))

        expected_warning = (
            '"blah" is an invalid value for theme.base.'
            " Allowed values include ['light', 'dark']. Setting theme.base to \"light\"."
        )

        msg = ForwardMsg()
        app_session._populate_theme_msg(msg.new_session.custom_theme)

        patched_logger.warning.assert_called_once_with(expected_warning)
예제 #28
0
    def test_logs_warning_if_font_invalid(self, patched_config,
                                          patched_logger):
        """An unknown theme.font value logs a warning naming the fallback."""
        patched_config.get_options_for_section.side_effect = (
            _mock_get_options_for_section({"font": "comic sans"}))

        expected_warning = (
            '"comic sans" is an invalid value for theme.font.'
            " Allowed values include ['sans serif', 'serif', 'monospace']. Setting theme.font to \"sans serif\"."
        )

        msg = ForwardMsg()
        app_session._populate_theme_msg(msg.new_session.custom_theme)

        patched_logger.warning.assert_called_once_with(expected_warning)
def create_reference_msg(msg: ForwardMsg) -> ForwardMsg:
    """Build a lightweight ForwardMsg that points at ``msg`` by its hash.

    Parameters
    ----------
    msg : ForwardMsg
        The message being referenced. Its hash is populated if needed.

    Returns
    -------
    ForwardMsg
        A new message whose ``ref_hash`` is ``msg``'s hash and whose
        metadata is a copy of ``msg.metadata``.

    """
    target_hash = populate_hash_if_needed(msg)

    reference = ForwardMsg()
    reference.ref_hash = target_hash
    reference.metadata.CopyFrom(msg.metadata)
    return reference
예제 #30
0
    def test_no_custom_theme_prop_if_no_theme(self, patched_config):
        """custom_theme stays unset when every theme color option is None."""
        unset_colors = {
            option: None
            for option in (
                "primaryColor",
                "backgroundColor",
                "secondaryBackgroundColor",
                "textColor",
            )
        }
        patched_config.get_options_for_section.side_effect = (
            _mock_get_options_for_section(unset_colors))

        msg = ForwardMsg()
        report_msg = msg.new_report
        report_session._populate_theme_msg(report_msg.custom_theme)

        self.assertFalse(report_msg.HasField("custom_theme"))