def _enqueue_script_finished_message(self, status):
    """Build a ForwardMsg carrying ``script_finished`` and enqueue it.

    Parameters
    ----------
    status : ScriptFinishedStatus
        Value stored in the message's ``script_finished`` field.

    """
    finished_msg = ForwardMsg()
    finished_msg.script_finished = status
    self.enqueue(finished_msg)
def enqueue(self, msg: ForwardMsg) -> None:
    """Enforce the ``set_page_config`` ordering rule, then forward ``msg``.

    Once a ``delta`` or ``page_config_changed`` message has been seen,
    ``self._set_page_config_allowed`` is cleared and any later
    ``page_config_changed`` message is rejected. Accepted messages are
    handed to ``self._enqueue``.

    Raises
    ------
    StreamlitAPIException
        If ``msg`` has ``page_config_changed`` set while
        ``self._set_page_config_allowed`` is falsy.

    """
    is_page_config = msg.HasField("page_config_changed")

    if is_page_config and not self._set_page_config_allowed:
        raise StreamlitAPIException(
            "`set_page_config()` can only be called once per app, "
            "and must be called as the first Streamlit command in your script.\n\n"
            "For more information refer to the [docs]"
            "(https://docs.streamlit.io/en/stable/api.html#streamlit.set_page_config)."
        )

    # Either kind of message closes the window for set_page_config.
    if is_page_config or msg.HasField("delta"):
        self._set_page_config_allowed = False

    self._enqueue(msg)
def handle_save_request(self, ws):
    """Save serialized version of report deltas to the cloud.

    Progress updates and the final outcome are written to ``ws`` as
    serialized ForwardMsg protobufs. The body uses ``yield``, so this is a
    generator; presumably it is driven as a tornado coroutine by a
    decorator outside this block -- TODO confirm.

    Parameters
    ----------
    ws
        Object exposing ``write_message(data, binary=True)``.

    Raises
    ------
    Exception
        Any exception raised while saving is re-raised after an error
        string has been written to ``ws`` via the ``report_uploaded`` field.

    """

    @tornado.gen.coroutine
    def progress(percent):
        update = ForwardMsg()
        update.upload_report_progress = percent
        yield ws.write_message(update.SerializeToString(), binary=True)

    def report_uploaded_payload(text):
        # Serialized ForwardMsg whose report_uploaded field holds `text`.
        outcome = ForwardMsg()
        outcome.report_uploaded = text
        return outcome.SerializeToString()

    try:
        # Indicate that the save is starting.
        yield progress(0)

        url = yield self._save_final_report(progress)

        # Indicate that the save is done.
        yield ws.write_message(report_uploaded_payload(url), binary=True)

    except Exception as e:
        # Horrible hack to show something if something breaks.
        err_msg = "%s: %s" % (type(e).__name__, str(e) or "No further details.")
        yield ws.write_message(report_uploaded_payload(err_msg), binary=True)
        raise e
def test_set_page_config_first(self, _1):
    """st.set_page_config must be called before other st commands"""
    uploaded_files = MagicMock(spec=UploadedFileManager)
    session = ReportSession(None, "", "", uploaded_files)

    # A regular element delta goes through first...
    element_msg = ForwardMsg()
    element_msg.delta.new_element.markdown.body = "foo"
    session.enqueue(element_msg)

    # ...so a page-config message afterwards has to be rejected.
    page_config_msg = ForwardMsg()
    page_config_msg.page_config_changed.title = "foo"
    with self.assertRaises(StreamlitAPIException):
        session.enqueue(page_config_msg)
def test_set_page_config_first(self):
    """st.set_page_config must be called before other st commands"""

    def fake_enqueue(msg):
        return None

    ctx = ReportContext("TestSessionID", fake_enqueue, "", None, None, None)

    # A regular element delta goes through first...
    element_msg = ForwardMsg()
    element_msg.delta.new_element.markdown.body = "foo"
    ctx.enqueue(element_msg)

    # ...so a page-config message afterwards has to be rejected.
    page_config_msg = ForwardMsg()
    page_config_msg.page_config_changed.title = "foo"
    with self.assertRaises(StreamlitAPIException):
        ctx.enqueue(page_config_msg)
def enqueue(self, msg: ForwardMsg) -> None:
    """Add message into queue, possibly composing it with another message.

    Messages for which ``_is_composable_message`` is false are appended
    unchanged. A composable message whose ``delta_path`` matches a delta
    already recorded in ``self._delta_index_map`` is offered to
    ``_maybe_compose_deltas``; when that returns a delta, the queued
    message is replaced in place. Otherwise the message is appended and
    its queue index is recorded under its ``delta_path``.

    All of this happens while holding ``self._lock``.

    """
    with self._lock:
        if not _is_composable_message(msg):
            self._queue.append(msg)
            return

        # Deltas that target the same delta_path refer to the same location
        # in the app, so a newer one may be folded into the older one. This
        # is an optimization that prevents redundant Deltas from being sent
        # to the frontend.
        delta_key = tuple(msg.metadata.delta_path)

        if delta_key in self._delta_index_map:
            queued_at = self._delta_index_map[delta_key]
            composed_delta = _maybe_compose_deltas(
                self._queue[queued_at].delta, msg.delta
            )
            if composed_delta is not None:
                replacement = ForwardMsg()
                replacement.delta.CopyFrom(composed_delta)
                replacement.metadata.CopyFrom(msg.metadata)
                self._queue[queued_at] = replacement
                return

        # No composition occurred. Append this message to the queue, and
        # store its index for potential future composition.
        self._delta_index_map[delta_key] = len(self._queue)
        self._queue.append(msg)
def _create_session_state_changed_message(self) -> ForwardMsg:
    """Create and return a session_state_changed ForwardMsg.

    The message's ``run_on_save`` mirrors ``self._run_on_save`` and its
    ``script_is_running`` is true exactly when ``self._state`` equals
    ``AppSessionState.APP_IS_RUNNING``.

    """
    msg = ForwardMsg()
    state_changed = msg.session_state_changed
    state_changed.run_on_save = self._run_on_save
    state_changed.script_is_running = (
        self._state == AppSessionState.APP_IS_RUNNING
    )
    return msg
def _enqueue_new_report_message(self):
    """Give the report a fresh id and enqueue a ``new_report`` ForwardMsg.

    The message carries the report's id, name and script path as read from
    ``self._report`` after ``generate_new_id()`` has been called.

    """
    report = self._report
    report.generate_new_id()

    msg = ForwardMsg()
    new_report = msg.new_report
    new_report.id = report.report_id
    new_report.name = report.name
    new_report.script_path = report.script_path

    self.enqueue(msg)
def _enqueue_session_state_changed_message(self):
    """Enqueue a ``session_state_changed`` ForwardMsg.

    ``run_on_save`` mirrors ``self._run_on_save``; ``report_is_running`` is
    true exactly when ``self._state`` equals
    ``ReportSessionState.REPORT_IS_RUNNING``.

    """
    msg = ForwardMsg()
    state_changed = msg.session_state_changed
    state_changed.run_on_save = self._run_on_save
    state_changed.report_is_running = (
        self._state == ReportSessionState.REPORT_IS_RUNNING
    )
    self.enqueue(msg)
def handle_git_information_request(self) -> None:
    """Enqueue a ``git_info_changed`` ForwardMsg for the session's script.

    Nothing is enqueued when ``get_repo_info()`` returns ``None``. Any
    exception raised while gathering the information is logged at debug
    level and otherwise ignored.

    """
    msg = ForwardMsg()

    try:
        from streamlit.git_util import GitRepo

        repo = GitRepo(self._session_data.script_path)

        repo_info = repo.get_repo_info()
        if repo_info is None:
            return

        repository_name, branch, module = repo_info

        git_info = msg.git_info_changed
        git_info.repository = repository_name
        git_info.branch = branch
        git_info.module = module
        git_info.untracked_files[:] = repo.untracked_files
        git_info.uncommitted_files[:] = repo.uncommitted_files

        # A detached HEAD takes precedence over being ahead of the remote.
        if repo.is_head_detached:
            git_info.state = GitInfo.GitStates.HEAD_DETACHED
        elif len(repo.ahead_commits) > 0:
            git_info.state = GitInfo.GitStates.AHEAD_OF_REMOTE
        else:
            git_info.state = GitInfo.GitStates.DEFAULT

        self.enqueue(msg)
    except Exception as e:
        # Users may never even install Git in the first place, so this
        # error requires no action. It can be useful for debugging.
        LOGGER.debug("Obtaining Git information produced an error", exc_info=e)
def _enqueue_session_state_changed_message(self) -> None:
    """Enqueue a ``session_state_changed`` ForwardMsg.

    ``run_on_save`` mirrors ``self._run_on_save``; ``script_is_running`` is
    true exactly when ``self._state`` equals
    ``AppSessionState.APP_IS_RUNNING``.

    """
    msg = ForwardMsg()
    state_changed = msg.session_state_changed
    state_changed.run_on_save = self._run_on_save
    state_changed.script_is_running = (
        self._state == AppSessionState.APP_IS_RUNNING
    )
    self.enqueue(msg)
def _maybe_enqueue_initialize_message(self): if self._sent_initialize_message: return self._sent_initialize_message = True msg = ForwardMsg() imsg = msg.initialize imsg.config.sharing_enabled = (config.get_option('global.sharingMode') != 'off') LOGGER.debug('New browser connection: sharing_enabled=%s', imsg.config.sharing_enabled) imsg.config.gather_usage_stats = ( config.get_option('browser.gatherUsageStats')) LOGGER.debug('New browser connection: gather_usage_stats=%s', imsg.config.gather_usage_stats) imsg.environment_info.streamlit_version = __version__ imsg.environment_info.python_version = ('.'.join( map(str, sys.version_info))) imsg.session_state.run_on_save = self._run_on_save imsg.session_state.report_is_running = ( self._state == ReportSessionState.REPORT_IS_RUNNING) imsg.user_info.installation_id = __installation_id__ imsg.user_info.email = Credentials.get_current().activation.email self.enqueue(msg)
def enqueue(self, msg):
    """Add message into queue, possibly composing it with another message.

    Non-delta messages are appended as-is. A delta message is keyed by its
    parent block (container and path) together with its delta id; when a
    message with the same key is already queued, the two deltas are
    combined with ``compose_deltas`` and the queued entry is replaced in
    place. Otherwise the message is appended and its index recorded.

    Parameters
    ----------
    msg : ForwardMsg

    """
    with self._lock:
        # Optimize only if it's a delta message.
        if not msg.HasField("delta"):
            self._queue.append(msg)
            return

        # Deltas are uniquely identified by the combination of their
        # container and ID.
        parent_block = msg.metadata.parent_block
        delta_key = (
            (parent_block.container, tuple(parent_block.path)),
            msg.metadata.delta_id,
        )

        if delta_key not in self._delta_index_map:
            # First delta for this key: append it and remember where it
            # lives for future combining.
            self._delta_index_map[delta_key] = len(self._queue)
            self._queue.append(msg)
            return

        # Combine the previous message into the new message.
        queued_at = self._delta_index_map[delta_key]
        composed_delta = compose_deltas(self._queue[queued_at].delta, msg.delta)

        replacement = ForwardMsg()
        replacement.delta.CopyFrom(composed_delta)
        replacement.metadata.CopyFrom(msg.metadata)
        self._queue[queued_at] = replacement
def _enqueue_new_session_message(self) -> None:
    """Build and enqueue a ``new_session`` ForwardMsg.

    The message gets a freshly generated script-run id, the session's name
    and script path, the config and theme sub-messages, and an
    ``initialize`` sub-message with user, environment and session-state
    information.

    """
    msg = ForwardMsg()
    new_session = msg.new_session

    new_session.script_run_id = _generate_scriptrun_id()
    new_session.name = self._session_data.name
    new_session.script_path = self._session_data.script_path

    _populate_config_msg(new_session.config)
    _populate_theme_msg(new_session.custom_theme)

    # Immutable session data. We send this every time a new session is
    # started, to avoid having to track whether the client has already
    # received it. It does not change from run to run; it's up to the
    # client to perform one-time initialization only once.
    imsg = new_session.initialize

    _populate_user_info_msg(imsg.user_info)

    env_info = imsg.environment_info
    env_info.streamlit_version = __version__
    env_info.python_version = ".".join(map(str, sys.version_info))

    session_state = imsg.session_state
    session_state.run_on_save = self._run_on_save
    session_state.script_is_running = (
        self._state == AppSessionState.APP_IS_RUNNING
    )

    imsg.command_line = self._session_data.command_line
    imsg.session_id = self.id

    self.enqueue(msg)
def handle_git_information_request(self):
    """Enqueue a ``git_info_changed`` ForwardMsg for the report's script.

    The ``GitRepo`` built for ``self._report.script_path`` is kept on
    ``self._repo``. Gathering Git information is best-effort: nothing is
    enqueued when no repo info is available, and any ``Exception`` raised
    along the way is logged at debug level and otherwise ignored.
    ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed.

    """
    msg = ForwardMsg()

    try:
        from streamlit.git_util import GitRepo

        self._repo = GitRepo(self._report.script_path)

        repo_info = self._repo.get_repo_info()
        if repo_info is None:
            # Nothing to report; previously this case fell into the bare
            # ``except`` via a failed tuple unpack.
            return

        repo, branch, module = repo_info

        msg.git_info_changed.repository = repo
        msg.git_info_changed.branch = branch
        msg.git_info_changed.module = module

        msg.git_info_changed.untracked_files[:] = self._repo.untracked_files
        msg.git_info_changed.uncommitted_files[:] = self._repo.uncommitted_files

        # A detached HEAD takes precedence over being ahead of the remote.
        if self._repo.is_head_detached:
            msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED
        elif len(self._repo.ahead_commits) > 0:
            msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE
        else:
            msg.git_info_changed.state = GitInfo.GitStates.DEFAULT

        self.enqueue(msg)
    except Exception as e:
        # Git may not be installed or the script may not live in a repo;
        # neither requires action, but the details help when debugging.
        LOGGER.debug("Obtaining Git information produced an error", exc_info=e)
def _send_message(self, session_info: SessionInfo, msg: ForwardMsg) -> None:
    """Send a message to a client.

    If the client is likely to have already cached the message, we may
    instead send a "reference" message that contains only the hash of the
    message.

    Side effects visible in this block: ``msg.metadata.cacheable`` is
    overwritten, cacheable messages are added to ``self._message_cache``,
    and a successfully finished report bumps
    ``session_info.report_run_count`` and expires old cache entries.

    Parameters
    ----------
    session_info : SessionInfo
        The SessionInfo associated with websocket
    msg : ForwardMsg
        The message to send to the client

    """
    msg.metadata.cacheable = is_cacheable_msg(msg)
    msg_to_send = msg

    if msg.metadata.cacheable:
        populate_hash_if_needed(msg)

        if self._message_cache.has_message_reference(
            msg, session_info.session, session_info.report_run_count
        ):
            # This session has probably cached this message. Send
            # a reference instead.
            LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
            msg_to_send = create_reference_msg(msg)

        # Cache the message so it can be referenced in the future.
        # If the message is already cached, this will reset its
        # age.
        LOGGER.debug("Caching message (hash=%s)", msg.hash)
        self._message_cache.add_message(
            msg, session_info.session, session_info.report_run_count
        )

    # If this was a `report_finished` message, we increment the
    # report_run_count for this session, and update the cache
    if (
        msg.WhichOneof("type") == "report_finished"
        and msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY
    ):
        LOGGER.debug(
            "Report finished successfully; "
            "removing expired entries from MessageCache "
            "(max_age=%s)",
            config.get_option("global.maxCachedMessageAge"),
        )
        session_info.report_run_count += 1
        self._message_cache.remove_expired_session_entries(
            session_info.session, session_info.report_run_count
        )

    # Ship it off! A session without a websocket silently drops the message.
    if session_info.ws is not None:
        session_info.ws.write_message(
            serialize_forward_msg(msg_to_send), binary=True
        )
def _enqueue_new_report_message(self):
    """Give the report a fresh id and enqueue a ``new_report`` ForwardMsg.

    Besides the report's id, name and script path, the message carries
    optional git deploy parameters (only when ``get_deploy_params()``
    returns a value) and an ``initialize`` sub-message with config,
    environment, session-state and user information.

    """
    self._report.generate_new_id()

    msg = ForwardMsg()
    new_report = msg.new_report
    new_report.report_id = self._report.report_id
    new_report.name = self._report.name
    new_report.script_path = self._report.script_path

    # git deploy params
    deploy_params = self.get_deploy_params()
    if deploy_params is not None:
        repo, branch, module = deploy_params
        git_params = new_report.deploy_params
        git_params.repository = repo
        git_params.branch = branch
        git_params.module = module

    # Immutable session data. We send this every time a new report is
    # started, to avoid having to track whether the client has already
    # received it. It does not change from run to run; it's up to the
    # client to perform one-time initialization only once.
    imsg = new_report.initialize

    cfg = imsg.config
    cfg.sharing_enabled = config.get_option("global.sharingMode") != "off"
    cfg.gather_usage_stats = config.get_option("browser.gatherUsageStats")
    cfg.max_cached_message_age = config.get_option("global.maxCachedMessageAge")
    cfg.mapbox_token = config.get_option("mapbox.token")
    cfg.allow_run_on_save = config.get_option("server.allowRunOnSave")

    env_info = imsg.environment_info
    env_info.streamlit_version = __version__
    env_info.python_version = ".".join(map(str, sys.version_info))

    session_state = imsg.session_state
    session_state.run_on_save = self._run_on_save
    session_state.report_is_running = (
        self._state == ReportSessionState.REPORT_IS_RUNNING
    )

    user_info = imsg.user_info
    user_info.installation_id = Installation.instance().installation_id
    user_info.installation_id_v1 = Installation.instance().installation_id_v1
    user_info.installation_id_v2 = Installation.instance().installation_id_v2
    user_info.installation_id_v3 = Installation.instance().installation_id_v3
    # Without an activation there is no email to report.
    user_info.email = (
        Credentials.get_current().activation.email
        if Credentials.get_current().activation
        else ""
    )

    imsg.command_line = self._report.command_line
    imsg.session_id = self.id

    self.enqueue(msg)
def enqueue(self, msg: ForwardMsg) -> None:
    """Enforce the ``set_page_config`` ordering rule, then forward ``msg``.

    Accepted messages are handed to ``self._enqueue``.

    Raises
    ------
    StreamlitAPIException
        If ``msg`` has ``page_config_changed`` set while
        ``self._set_page_config_allowed`` is falsy.

    """
    is_page_config = msg.HasField("page_config_changed")

    if is_page_config and not self._set_page_config_allowed:
        raise StreamlitAPIException(
            "`set_page_config()` can only be called once per app, "
            "and must be called as the first Streamlit command in your script.\n\n"
            "For more information refer to the [docs]"
            "(https://docs.streamlit.io/library/api-reference/utilities/st.set_page_config)."
        )

    # We want to disallow set_page_config if one of the following occurs:
    # - set_page_config was called on this message
    # - The script has already started and a different st call occurs (a delta)
    closes_window = is_page_config or (
        msg.HasField("delta") and self._has_script_started
    )
    if closes_window:
        self._set_page_config_allowed = False

    self._enqueue(msg)
def _maybe_enqueue_initialize_message(self): if self._sent_initialize_message: return self._sent_initialize_message = True msg = ForwardMsg() imsg = msg.initialize imsg.config.sharing_enabled = config.get_option( "global.sharingMode") != "off" imsg.config.gather_usage_stats = config.get_option( "browser.gatherUsageStats") imsg.config.max_cached_message_age = config.get_option( "global.maxCachedMessageAge") imsg.config.mapbox_token = config.get_option("mapbox.token") imsg.config.allow_run_on_save = config.get_option( "server.allowRunOnSave") LOGGER.debug( "New browser connection: " "gather_usage_stats=%s, " "sharing_enabled=%s, " "max_cached_message_age=%s", imsg.config.gather_usage_stats, imsg.config.sharing_enabled, imsg.config.max_cached_message_age, ) imsg.environment_info.streamlit_version = __version__ imsg.environment_info.python_version = ".".join( map(str, sys.version_info)) imsg.session_state.run_on_save = self._run_on_save imsg.session_state.report_is_running = ( self._state == ReportSessionState.REPORT_IS_RUNNING) imsg.user_info.installation_id = Installation.instance( ).installation_id imsg.user_info.installation_id_v1 = Installation.instance( ).installation_id_v1 imsg.user_info.installation_id_v2 = Installation.instance( ).installation_id_v2 if Credentials.get_current().activation: imsg.user_info.email = Credentials.get_current().activation.email else: imsg.user_info.email = "" imsg.command_line = self._report.command_line imsg.session_id = self.id self.enqueue(msg)
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
    """Serialize a ForwardMsg to send to a client.

    If the serialized message is larger than
    ``get_max_message_size_bytes()``, ``msg`` is mutated so that its delta
    carries a ``MessageSizeError`` exception element, and that version is
    serialized instead.

    """
    populate_hash_if_needed(msg)
    serialized = msg.SerializeToString()

    if len(serialized) <= get_max_message_size_bytes():
        return serialized

    import streamlit.elements.exception as exception

    # Overwrite the offending ForwardMsg.delta with an error to display.
    # This assumes that the size limit wasn't exceeded due to metadata.
    exception.marshall(
        msg.delta.new_element.exception, MessageSizeError(serialized)
    )
    return msg.SerializeToString()
def test_set_page_config_immutable(self, _1):
    """st.set_page_config must be called at most once"""
    uploaded_files = MagicMock(spec=UploadedFileManager)
    session = ReportSession(None, "", "", uploaded_files)

    page_config_msg = ForwardMsg()
    page_config_msg.page_config_changed.title = "foo"

    # The first page-config message is accepted...
    session.enqueue(page_config_msg)

    # ...and sending the same one again must be rejected.
    with self.assertRaises(StreamlitAPIException):
        session.enqueue(page_config_msg)
def _is_composable_message(msg: ForwardMsg) -> bool:
    """True if the ForwardMsg is potentially composable with other ForwardMsgs."""
    # Non-delta messages are never composable.
    if not msg.HasField("delta"):
        return False

    # We never compose add_rows messages in Python, because the add_rows
    # operation can raise errors, and we don't have a good way of handling
    # those errors in the message queue.
    return msg.delta.WhichOneof("type") not in ("add_rows", "arrow_add_rows")
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
    """Serialize a ForwardMsg to send to a client.

    If the serialized message is larger than ``MESSAGE_SIZE_LIMIT``,
    ``msg`` is mutated so that its delta carries a ``RuntimeError``
    exception element describing the overflow, and that version is
    serialized instead.

    """
    populate_hash_if_needed(msg)
    msg_str = msg.SerializeToString()

    if len(msg_str) <= MESSAGE_SIZE_LIMIT:
        return msg_str

    import streamlit.elements.exception as exception

    error = RuntimeError(
        f"Data of size {len(msg_str)/1e6:.1f}MB exceeds write limit of {MESSAGE_SIZE_LIMIT/1e6}MB"
    )

    # Overwrite the offending ForwardMsg.delta with an error to display.
    # This assumes that the size limit wasn't exceeded due to metadata.
    exception.marshall(msg.delta.new_element.exception, error)
    return msg.SerializeToString()
def test_set_page_config_immutable(self):
    """st.set_page_config must be called at most once"""

    def fake_enqueue(msg):
        return None

    ctx = ReportContext("TestSessionID", fake_enqueue, "", None, None, None)

    page_config_msg = ForwardMsg()
    page_config_msg.page_config_changed.title = "foo"

    # The first page-config message is accepted...
    ctx.enqueue(page_config_msg)

    # ...and sending the same one again must be rejected.
    with self.assertRaises(StreamlitAPIException):
        ctx.enqueue(page_config_msg)
def test_can_specify_all_options(self, patched_config):
    """With every theme option provided, custom_theme is populated."""
    # Specifies all options by default.
    all_options = _mock_get_options_for_section()
    patched_config.get_options_for_section.side_effect = all_options

    new_report_msg = ForwardMsg().new_report
    report_session._populate_theme_msg(new_report_msg.custom_theme)
    theme = new_report_msg.custom_theme

    self.assertEqual(new_report_msg.HasField("custom_theme"), True)
    self.assertEqual(theme.primary_color, "coral")
    self.assertEqual(theme.background_color, "white")
def test_disallow_set_page_config_twice(self):
    """st.set_page_config cannot be called twice"""

    def fake_enqueue(msg):
        return None

    ctx = ScriptRunContext(
        "TestSessionID",
        fake_enqueue,
        "",
        SessionState(),
        UploadedFileManager(),
    )
    ctx.on_script_start()

    # The first page-config message is accepted.
    msg = ForwardMsg()
    msg.page_config_changed.title = "foo"
    ctx.enqueue(msg)

    # Build the second message outside the assertRaises block so that only
    # the enqueue call itself can satisfy the assertion.
    same_msg = ForwardMsg()
    same_msg.page_config_changed.title = "bar"

    with self.assertRaises(StreamlitAPIException):
        ctx.enqueue(same_msg)
def test_logs_warning_if_base_invalid(self, patched_config, patched_logger):
    """An invalid theme.base value triggers exactly one warning."""
    invalid_base_options = _mock_get_options_for_section({"base": "blah"})
    patched_config.get_options_for_section.side_effect = invalid_base_options

    new_session_msg = ForwardMsg().new_session
    app_session._populate_theme_msg(new_session_msg.custom_theme)

    expected_warning = (
        '"blah" is an invalid value for theme.base.'
        " Allowed values include ['light', 'dark']. Setting theme.base to \"light\"."
    )
    patched_logger.warning.assert_called_once_with(expected_warning)
def test_logs_warning_if_font_invalid(self, patched_config, patched_logger):
    """An invalid theme.font value triggers exactly one warning."""
    invalid_font_options = _mock_get_options_for_section({"font": "comic sans"})
    patched_config.get_options_for_section.side_effect = invalid_font_options

    new_session_msg = ForwardMsg().new_session
    app_session._populate_theme_msg(new_session_msg.custom_theme)

    expected_warning = (
        '"comic sans" is an invalid value for theme.font.'
        " Allowed values include ['sans serif', 'serif', 'monospace']. Setting theme.font to \"sans serif\"."
    )
    patched_logger.warning.assert_called_once_with(expected_warning)
def create_reference_msg(msg: ForwardMsg) -> ForwardMsg:
    """Create a ForwardMsg that refers to the given message via its hash.

    The reference message also receives a copy of the source message's
    metadata.

    Parameters
    ----------
    msg : ForwardMsg
        The ForwardMsg to create the reference to.

    Returns
    -------
    ForwardMsg
        A new ForwardMsg whose ``ref_hash`` field holds the value returned
        by ``populate_hash_if_needed(msg)``.

    """
    reference = ForwardMsg()
    reference.ref_hash = populate_hash_if_needed(msg)
    reference.metadata.CopyFrom(msg.metadata)
    return reference
def test_no_custom_theme_prop_if_no_theme(self, patched_config):
    """With all color options unset, custom_theme is left absent."""
    unset_colors = {
        "primaryColor": None,
        "backgroundColor": None,
        "secondaryBackgroundColor": None,
        "textColor": None,
    }
    patched_config.get_options_for_section.side_effect = (
        _mock_get_options_for_section(unset_colors)
    )

    new_report_msg = ForwardMsg().new_report
    report_session._populate_theme_msg(new_report_msg.custom_theme)

    self.assertEqual(new_report_msg.HasField("custom_theme"), False)