def test_dequeue(self):
    """Test that we can enqueue and dequeue on different threads."""
    queue = ScriptRequestQueue()

    # Dequeuing from an empty queue returns (None, None) immediately.
    self.assertEqual((None, None), queue.dequeue())

    lock = Lock()
    dequeued_evt = [None]

    def get_event():
        with lock:
            return dequeued_evt[0]

    def set_event(value):
        with lock:
            dequeued_evt[0] = value

    def do_dequeue():
        # Poll until the main thread enqueues a request.
        event = None
        while event is None:
            event, _ = queue.dequeue()
        set_event(event)

    # daemon=True: if the request is never observed, the polling thread
    # must not keep the test process alive after the test has failed.
    thread = Thread(target=do_dequeue, name="test_dequeue", daemon=True)
    thread.start()

    # Nothing has been enqueued yet, so the worker cannot have recorded an
    # event.
    self.assertIsNone(get_event())

    queue.enqueue(ScriptRequest.STOP)

    # Wait for the worker to finish instead of sleeping for a fixed interval:
    # a fixed sleep races against the scheduler and makes the test flaky on
    # slow machines. join() also guarantees we observe the worker's write.
    thread.join(timeout=5)
    self.assertFalse(thread.is_alive())
    self.assertEqual(ScriptRequest.STOP, get_event())
class TestScriptRunner(ScriptRunner):
    """ScriptRunner subclass that captures its output and events for tests."""

    def __init__(self, script_name):
        """Build a runner for a script located in the ``test_data`` directory.

        Parameters
        ----------
        script_name : str
            File name of the script, relative to the ``test_data`` directory
            that sits next to this module.
        """
        # Every message the script produces is collected in this queue.
        self.report_queue = ReportQueue()

        def enqueue_fn(msg):
            self.report_queue.enqueue(msg)
            # Let the runner handle any pending execution-control request
            # each time the script emits a message.
            self.maybe_handle_execution_control_request()

        self.script_request_queue = ScriptRequestQueue()
        test_data_dir = os.path.join(os.path.dirname(__file__), "test_data")
        script_path = os.path.join(test_data_dir, script_name)

        super().__init__(
            session_id="test session id",
            report=Report(script_path, "test command line"),
            enqueue_forward_msg=enqueue_fn,
            client_state=ClientState(),
            request_queue=self.script_request_queue,
        )

        # Uncaught exceptions raised on the run thread are appended here.
        self.script_thread_exceptions = []

        # Every ScriptRunnerEvent we emit is appended here, in order.
        self.events = []

        def record_event(event, **kwargs):
            self.events.append(event)

        # weak=False presumably stops the signal from dropping this local
        # closure once __init__ returns (assumes a blinker-style Signal).
        self.on_event.connect(record_event, weak=False)

    def enqueue_rerun(self, argv=None, widget_states=None):
        """Queue a RERUN request carrying ``widget_states``.

        ``argv`` is accepted for call-site compatibility but is unused.
        """
        rerun_data = RerunData(widget_states=widget_states)
        self.script_request_queue.enqueue(ScriptRequest.RERUN, rerun_data)

    def enqueue_stop(self):
        """Queue a STOP request."""
        self.script_request_queue.enqueue(ScriptRequest.STOP)

    def enqueue_shutdown(self):
        """Queue a SHUTDOWN request."""
        self.script_request_queue.enqueue(ScriptRequest.SHUTDOWN)

    def _process_request_queue(self):
        # Record (instead of propagating) anything the run thread raises, so
        # the test thread can inspect it afterwards.
        try:
            super()._process_request_queue()
        except BaseException as e:
            self.script_thread_exceptions.append(e)

    def _run_script(self, rerun_data):
        # Each run starts from an empty message queue.
        self.report_queue.clear()
        super()._run_script(rerun_data)

    def join(self):
        """Block until the run thread exits; no-op if none was started."""
        if self._script_thread is None:
            return
        self._script_thread.join()

    def clear_deltas(self):
        """Discard every message currently held in our ReportQueue."""
        self.report_queue.clear()

    def deltas(self) -> List[Delta]:
        """Return the ``delta`` of each queued message that carries one."""
        queued_msgs = self.report_queue._queue
        return [m.delta for m in queued_msgs if m.HasField("delta")]

    def elements(self) -> List[Element]:
        """Return the ``new_element`` of every queued delta."""
        return [d.new_element for d in self.deltas()]

    def text_deltas(self) -> List[str]:
        """Return the body text of every queued ``text`` element."""
        return [
            el.text.body
            for el in self.elements()
            if el.WhichOneof("type") == "text"
        ]

    def get_widget_id(self, widget_type, label):
        """Return the id of the first ``widget_type`` widget with ``label``.

        Returns None when no queued delta contains a matching widget.
        """
        for d in self.deltas():
            element = getattr(d, "new_element", None)
            widget = getattr(element, widget_type, None)
            if getattr(widget, "label", None) == label:
                return widget.id
        return None
class ReportSession(object):
    """
    Contains session data for a single "user" of an active report
    (that is, a connected browser tab).

    Each ReportSession has its own Report, root DeltaGenerator, ScriptRunner,
    and widget state.

    A ReportSession is attached to each thread involved in running its Report.
    """

    def __init__(self, ioloop, script_path, command_line, uploaded_file_manager):
        """Initialize the ReportSession.

        Parameters
        ----------
        ioloop : tornado.ioloop.IOLoop
            The Tornado IOLoop that we're running within.

        script_path : str
            Path of the Python file from which this report is generated.

        command_line : str
            Command line as input by the user.

        uploaded_file_manager : UploadedFileManager
            The server's UploadedFileManager.

        """
        # Each ReportSession has a unique string ID.
        self.id = str(uuid.uuid4())

        self._ioloop = ioloop
        self._report = Report(script_path, command_line)
        self._uploaded_file_mgr = uploaded_file_manager

        self._state = ReportSessionState.REPORT_NOT_RUNNING

        # Need to remember the client state here because when a script reruns
        # due to the source code changing we need to pass in the previous
        # client state.
        self._client_state = ClientState()

        self._local_sources_watcher = LocalSourcesWatcher(
            self._report, self._on_source_file_changed
        )

        self._sent_initialize_message = False
        self._storage = None
        self._maybe_reuse_previous_run = False
        self._run_on_save = config.get_option("server.runOnSave")

        # The ScriptRequestQueue is the means by which we communicate
        # with the active ScriptRunner.
        self._script_request_queue = ScriptRequestQueue()

        self._scriptrunner = None

        LOGGER.debug("ReportSession initialized (id=%s)", self.id)

    def flush_browser_queue(self):
        """Clear the report queue and return the messages it contained.

        The Server calls this periodically to deliver new messages
        to the browser connected to this report.

        Returns
        -------
        list[ForwardMsg]
            The messages that were removed from the queue and should
            be delivered to the browser.

        """
        return self._report.flush_browser_queue()

    def shutdown(self):
        """Shut down the ReportSession.

        It's an error to use a ReportSession after it's been shut down.
        Calling this more than once is a no-op after the first call.

        """
        if self._state != ReportSessionState.SHUTDOWN_REQUESTED:
            LOGGER.debug("Shutting down (id=%s)", self.id)
            # Clear any unused session files in upload file manager and media
            # file manager
            self._uploaded_file_mgr.remove_session_files(self.id)
            media_file_manager.clear_session_files(self.id)
            media_file_manager.del_expired_files()

            # Shut down the ScriptRunner, if one is active.
            # self._state must not be set to SHUTDOWN_REQUESTED until
            # after this is called: _enqueue_script_request discards every
            # request once the state is SHUTDOWN_REQUESTED.
            if self._scriptrunner is not None:
                self._enqueue_script_request(ScriptRequest.SHUTDOWN)

            self._state = ReportSessionState.SHUTDOWN_REQUESTED
            self._local_sources_watcher.close()

    def enqueue(self, msg):
        """Enqueue a new ForwardMsg to our browser queue.

        This can be called on both the main thread and a ScriptRunner
        run thread. Does nothing when the ``client.displayEnabled`` option
        is off.

        Parameters
        ----------
        msg : ForwardMsg
            The message to enqueue

        """
        if not config.get_option("client.displayEnabled"):
            return

        # Avoid having two maybe_handle_execution_control_request running on
        # top of each other when tracer is installed. This leads to a lock
        # contention.
        if not config.get_option("runner.installTracer"):
            # If we have an active ScriptRunner, signal that it can handle an
            # execution control request. (Copy the scriptrunner reference to
            # avoid it being unset from underneath us, as this function can be
            # called outside the main thread.)
            scriptrunner = self._scriptrunner

            if scriptrunner is not None:
                scriptrunner.maybe_handle_execution_control_request()

        self._report.enqueue(msg)

    def enqueue_exception(self, e):
        """Enqueue an Exception message.

        Parameters
        ----------
        e : BaseException
            The exception to display in the browser.

        """
        # This does a few things:
        # 1) Clears the current report in the browser.
        # 2) Marks the current report as "stopped" in the browser.
        # 3) HACK: Resets any script params that may have been broken (e.g. the
        # command-line when rerunning with wrong argv[0])
        self._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS)
        self._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STARTED)
        self._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS)

        msg = ForwardMsg()
        msg.metadata.delta_id = 0
        exception_proto.marshall(msg.delta.new_element.exception, e)

        self.enqueue(msg)

    def request_rerun(self, client_state=None):
        """Signal that we're interested in running the script.

        If the script is not already running, it will be started immediately.
        Otherwise, a rerun will be requested.

        Parameters
        ----------
        client_state : streamlit.proto.ClientState_pb2.ClientState | None
            The ClientState protobuf to run the script with, or None
            to use previous client state.

        """
        if client_state:
            rerun_data = RerunData(
                client_state.query_string, client_state.widget_states
            )
        else:
            rerun_data = RerunData()

        self._enqueue_script_request(ScriptRequest.RERUN, rerun_data)
        # NOTE(review): this attribute is only assigned here within this
        # class; presumably it is read elsewhere — confirm.
        self._set_page_config_allowed = True

    def _on_source_file_changed(self):
        """One of our source files changed. Schedule a rerun if appropriate."""
        if self._run_on_save:
            self.request_rerun()
        else:
            self._enqueue_file_change_message()

    def _clear_queue(self):
        """Clear the messages queued on our Report."""
        self._report.clear()

    def _on_scriptrunner_event(self, event, exception=None, client_state=None):
        """Called when our ScriptRunner emits an event.

        This is *not* called on the main thread.

        Parameters
        ----------
        event : ScriptRunnerEvent

        exception : BaseException | None
            An exception thrown during compilation. Set only for the
            SCRIPT_STOPPED_WITH_COMPILE_ERROR event.

        client_state : streamlit.proto.ClientState_pb2.ClientState | None
            The ScriptRunner's final ClientState. Set only for the
            SHUTDOWN event.

        """
        LOGGER.debug("OnScriptRunnerEvent: %s", event)

        prev_state = self._state

        if event == ScriptRunnerEvent.SCRIPT_STARTED:
            if self._state != ReportSessionState.SHUTDOWN_REQUESTED:
                self._state = ReportSessionState.REPORT_IS_RUNNING

            if config.get_option("server.liveSave"):
                # Enqueue into the IOLoop so it runs without blocking AND runs
                # on the main thread.
                self._ioloop.spawn_callback(self._save_running_report)

            self._clear_queue()
            self._maybe_enqueue_initialize_message()
            self._enqueue_new_report_message()

        elif (
            event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
            or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR
        ):
            if self._state != ReportSessionState.SHUTDOWN_REQUESTED:
                self._state = ReportSessionState.REPORT_NOT_RUNNING

            script_succeeded = event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS

            self._enqueue_report_finished_message(
                ForwardMsg.FINISHED_SUCCESSFULLY
                if script_succeeded
                else ForwardMsg.FINISHED_WITH_COMPILE_ERROR
            )

            if config.get_option("server.liveSave"):
                # Enqueue into the IOLoop so it runs without blocking AND runs
                # on the main thread.
                self._ioloop.spawn_callback(self._save_final_report_and_quit)

            if script_succeeded:
                # When a script completes successfully, we update our
                # LocalSourcesWatcher to account for any source code changes
                # that change which modules should be watched. (This is run on
                # the main thread, because LocalSourcesWatcher is not
                # thread safe.)
                self._ioloop.spawn_callback(
                    self._local_sources_watcher.update_watched_modules
                )
            else:
                # When a script fails to compile, we send along the exception.
                from streamlit.elements import exception_proto

                msg = ForwardMsg()
                exception_proto.marshall(
                    msg.session_event.script_compilation_exception, exception
                )
                self.enqueue(msg)

        elif event == ScriptRunnerEvent.SHUTDOWN:
            # When ScriptRunner shuts down, update our local reference to it,
            # and check to see if we need to spawn a new one. (This is run on
            # the main thread.)

            if self._state == ReportSessionState.SHUTDOWN_REQUESTED:
                # Only clear media files if the script is done running AND the
                # report session is actually shutting down.
                media_file_manager.clear_session_files(self.id)

            def on_shutdown():
                self._client_state = client_state
                self._scriptrunner = None
                # Because a new ScriptEvent could have been enqueued while the
                # scriptrunner was shutting down, we check to see if we should
                # create a new one. (Otherwise, a newly-enqueued ScriptEvent
                # won't be processed until another event is enqueued.)
                self._maybe_create_scriptrunner()

            self._ioloop.spawn_callback(on_shutdown)

        # Send a message if our run state changed
        report_was_running = prev_state == ReportSessionState.REPORT_IS_RUNNING
        report_is_running = self._state == ReportSessionState.REPORT_IS_RUNNING
        if report_is_running != report_was_running:
            self._enqueue_session_state_changed_message()

    def _enqueue_session_state_changed_message(self):
        """Tell the browser our current run_on_save and is-running flags."""
        msg = ForwardMsg()
        msg.session_state_changed.run_on_save = self._run_on_save
        msg.session_state_changed.report_is_running = (
            self._state == ReportSessionState.REPORT_IS_RUNNING
        )
        self.enqueue(msg)

    def _enqueue_file_change_message(self):
        """Tell the browser that the report's source changed on disk."""
        LOGGER.debug("Enqueuing report_changed message (id=%s)", self.id)
        msg = ForwardMsg()
        msg.session_event.report_changed_on_disk = True
        self.enqueue(msg)

    def _maybe_enqueue_initialize_message(self):
        """Send the one-time ``initialize`` message, if not already sent."""
        if self._sent_initialize_message:
            return

        self._sent_initialize_message = True

        msg = ForwardMsg()
        imsg = msg.initialize

        imsg.config.sharing_enabled = config.get_option("global.sharingMode") != "off"

        imsg.config.gather_usage_stats = config.get_option("browser.gatherUsageStats")

        imsg.config.max_cached_message_age = config.get_option(
            "global.maxCachedMessageAge"
        )

        imsg.config.mapbox_token = config.get_option("mapbox.token")

        LOGGER.debug(
            "New browser connection: "
            "gather_usage_stats=%s, "
            "sharing_enabled=%s, "
            "max_cached_message_age=%s",
            imsg.config.gather_usage_stats,
            imsg.config.sharing_enabled,
            imsg.config.max_cached_message_age,
        )

        imsg.environment_info.streamlit_version = __version__
        imsg.environment_info.python_version = ".".join(map(str, sys.version_info))

        imsg.session_state.run_on_save = self._run_on_save
        imsg.session_state.report_is_running = (
            self._state == ReportSessionState.REPORT_IS_RUNNING
        )

        # Fetch the installation and credentials objects once instead of
        # repeating the accessor call for every field.
        installation = Installation.instance()
        imsg.user_info.installation_id = installation.installation_id
        imsg.user_info.installation_id_v1 = installation.installation_id_v1
        imsg.user_info.installation_id_v2 = installation.installation_id_v2

        activation = Credentials.get_current().activation
        if activation:
            imsg.user_info.email = activation.email
        else:
            imsg.user_info.email = ""

        imsg.command_line = self._report.command_line
        imsg.session_id = self.id

        self.enqueue(msg)

    def _enqueue_new_report_message(self):
        """Assign the report a fresh id and announce it to the browser."""
        self._report.generate_new_id()
        msg = ForwardMsg()
        msg.new_report.id = self._report.report_id
        msg.new_report.name = self._report.name
        msg.new_report.script_path = self._report.script_path
        self.enqueue(msg)

    def _enqueue_report_finished_message(self, status):
        """Enqueue a report_finished ForwardMsg.

        Parameters
        ----------
        status : ReportFinishedStatus

        """
        msg = ForwardMsg()
        msg.report_finished = status
        self.enqueue(msg)

    def handle_rerun_script_request(self, client_state=None, is_preheat=False):
        """Tell the ScriptRunner to re-run its report.

        Parameters
        ----------
        client_state : streamlit.proto.ClientState_pb2.ClientState | None
            The ClientState protobuf to run the script with, or None
            to use previous client state.

        is_preheat: boolean
            True if this ReportSession should run the script immediately, and
            then ignore the next rerun request if it matches the already-ran
            widget state.

        """
        if is_preheat:
            self._maybe_reuse_previous_run = True  # For next time.

        elif self._maybe_reuse_previous_run:
            # If this is a "preheated" ReportSession, reuse the previous run if
            # the widget state matches. But only do this one time ever.
            self._maybe_reuse_previous_run = False

            has_client_state = False

            if client_state is not None:
                has_query_string = client_state.query_string != ""
                has_widget_states = (
                    client_state.widget_states is not None
                    and len(client_state.widget_states.widgets) > 0
                )
                has_client_state = has_query_string or has_widget_states

            if not has_client_state:
                LOGGER.debug("Skipping rerun since the preheated run is the same")
                return

        self.request_rerun(client_state)

    def handle_stop_script_request(self):
        """Tell the ScriptRunner to stop running its report."""
        self._enqueue_script_request(ScriptRequest.STOP)

    def handle_clear_cache_request(self):
        """Clear this report's cache.

        Because this cache is global, it will be cleared for all users.

        """
        # clear_cache() is deliberately called without verbose=True: verbose
        # mode prints to stdout, and since this command was initiated from the
        # browser, the user doesn't need to see the result in their terminal.
        caching.clear_cache()

    def handle_set_run_on_save_request(self, new_value):
        """Change our run_on_save flag to the given value.

        The browser will be notified of the change.

        Parameters
        ----------
        new_value : bool
            New run_on_save value

        """
        self._run_on_save = new_value
        self._enqueue_session_state_changed_message()

    def _enqueue_script_request(self, request, data=None):
        """Enqueue a ScriptRequest into our ScriptRequestQueue.

        If a script thread is not already running, one will be created
        to handle the request. Requests made after shutdown was requested
        are logged and discarded.

        Parameters
        ----------
        request : ScriptRequest
            The type of request.

        data : Any
            Data associated with the request, if any.

        """
        if self._state == ReportSessionState.SHUTDOWN_REQUESTED:
            # Pass the argument to the logger so formatting is deferred until
            # the record is actually emitted.
            LOGGER.warning("Discarding %s request after shutdown", request)
            return

        self._script_request_queue.enqueue(request, data)
        self._maybe_create_scriptrunner()

    def _maybe_create_scriptrunner(self):
        """Create a new ScriptRunner if we have unprocessed script requests.

        This is called every time a ScriptRequest is enqueued, and also
        after a ScriptRunner shuts down, in case new requests were enqueued
        during its termination.

        This function should only be called on the main thread.

        """
        if (
            self._state == ReportSessionState.SHUTDOWN_REQUESTED
            or self._scriptrunner is not None
            or not self._script_request_queue.has_request
        ):
            return

        # Create the ScriptRunner, attach event handlers, and start it
        self._scriptrunner = ScriptRunner(
            session_id=self.id,
            report=self._report,
            enqueue_forward_msg=self.enqueue,
            client_state=self._client_state,
            request_queue=self._script_request_queue,
            uploaded_file_mgr=self._uploaded_file_mgr,
        )
        self._scriptrunner.on_event.connect(self._on_scriptrunner_event)
        self._scriptrunner.start()

    @tornado.gen.coroutine
    def handle_save_request(self, ws):
        """Save serialized version of report deltas to the cloud.

        "Progress" ForwardMsgs will be sent to the client during the upload.
        These messages are sent "out of band" - that is, they don't get
        enqueued into the ReportQueue (because they're not part of the report).
        Instead, they're written directly to the report's WebSocket.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler
            The report's websocket handler.

        """

        @tornado.gen.coroutine
        def progress(percent):
            progress_msg = ForwardMsg()
            progress_msg.upload_report_progress = percent
            yield ws.write_message(serialize_forward_msg(progress_msg), binary=True)

        # Indicate that the save is starting.
        try:
            yield progress(0)

            url = yield self._save_final_report(progress)

            # Indicate that the save is done.
            progress_msg = ForwardMsg()
            progress_msg.report_uploaded = url
            yield ws.write_message(serialize_forward_msg(progress_msg), binary=True)

        except Exception as e:
            # Horrible hack to show something if something breaks.
            err_msg = "%s: %s" % (type(e).__name__, str(e) or "No further details.")
            progress_msg = ForwardMsg()
            progress_msg.report_uploaded = err_msg
            yield ws.write_message(serialize_forward_msg(progress_msg), binary=True)

            LOGGER.warning("Failed to save report:", exc_info=e)

    @tornado.gen.coroutine
    def _save_running_report(self):
        """Upload a snapshot of the still-running report; resolves to its URL."""
        files = self._report.serialize_running_report_to_files()
        url = yield self._get_storage().save_report_files(
            self._report.report_id, files
        )

        if config.get_option("server.liveSave"):
            url_util.print_url("Saved running app", url)

        raise tornado.gen.Return(url)

    @tornado.gen.coroutine
    def _save_final_report(self, progress_coroutine=None):
        """Upload the finished report; resolves to its URL.

        ``progress_coroutine`` is forwarded to the storage backend.
        """
        files = self._report.serialize_final_report_to_files()
        url = yield self._get_storage().save_report_files(
            self._report.report_id, files, progress_coroutine
        )

        if config.get_option("server.liveSave"):
            url_util.print_url("Saved final app", url)

        raise tornado.gen.Return(url)

    @tornado.gen.coroutine
    def _save_final_report_and_quit(self):
        """Upload the finished report, then stop our IOLoop."""
        yield self._save_final_report()
        self._ioloop.stop()

    def _get_storage(self):
        """Return the storage backend, creating it on first use.

        The backend is chosen by the ``global.sharingMode`` option.

        Raises
        ------
        RuntimeError
            If the sharing mode is neither "s3" nor "file".

        """
        if self._storage is None:
            sharing_mode = config.get_option("global.sharingMode")
            if sharing_mode == "s3":
                self._storage = S3Storage()
            elif sharing_mode == "file":
                self._storage = FileStorage()
            else:
                raise RuntimeError("Unsupported sharing mode '%s'" % sharing_mode)
        return self._storage
def test_rerun_data_coalescing(self):
    """Consecutive RERUN requests are merged into a single request.

    This checks the same coalescing rules as
    widgets_test.test_coalesce_widget_states, but exercises them through
    the ScriptRequestQueue rather than the coalescing code directly.
    """
    queue = ScriptRequestQueue()
    session_state = SessionState()

    def rerun(widget_states):
        queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=widget_states))

    def assert_queue_empty():
        self.assertEqual(
            (None, None), queue.dequeue(), "Expected empty event queue"
        )

    # --- Two RERUNs that both carry widget state -------------------------
    first = WidgetStates()
    _create_widget("trigger", first).trigger_value = True
    _create_widget("int", first).int_value = 123
    rerun(first)

    second = WidgetStates()
    _create_widget("trigger", second).trigger_value = False
    _create_widget("int", second).int_value = 456

    session_state.set_metadata(
        WidgetMetadata("trigger", lambda x, s: x, None, "trigger_value")
    )
    session_state.set_metadata(
        WidgetMetadata("int", lambda x, s: x, lambda x: x, "int_value")
    )

    rerun(second)

    event, data = queue.dequeue()
    self.assertEqual(event, ScriptRequest.RERUN)

    session_state.set_widgets_from_proto(data.widget_states)

    # A coalesced trigger is True if either the old or the new value was True.
    self.assertEqual(True, session_state.get("trigger"))
    # Every other widget keeps its newest value.
    self.assertEqual(456, session_state.get("int"))

    assert_queue_empty()

    # --- Earlier requests carry no widget state --------------------------
    rerun(None)
    rerun(None)

    latest = WidgetStates()
    _create_widget("int", latest).int_value = 789
    rerun(latest)

    event, data = queue.dequeue()
    session_state.set_widgets_from_proto(data.widget_states)

    self.assertEqual(event, ScriptRequest.RERUN)
    self.assertEqual(789, session_state.get("int"))

    assert_queue_empty()

    # --- The newest request carries no widget state ----------------------
    earlier = WidgetStates()
    _create_widget("int", earlier).int_value = 101112
    rerun(earlier)

    rerun(None)

    event, data = queue.dequeue()
    session_state.set_widgets_from_proto(data.widget_states)

    self.assertEqual(event, ScriptRequest.RERUN)
    self.assertEqual(101112, session_state.get("int"))

    assert_queue_empty()
class AppSession: """ Contains session data for a single "user" of an active app (that is, a connected browser tab). Each AppSession has its own SessionData, root DeltaGenerator, ScriptRunner, and widget state. An AppSession is attached to each thread involved in running its script. """ def __init__( self, ioloop: tornado.ioloop.IOLoop, session_data: SessionData, uploaded_file_manager: UploadedFileManager, message_enqueued_callback: Optional[Callable[[], None]], local_sources_watcher: LocalSourcesWatcher, ): """Initialize the AppSession. Parameters ---------- ioloop : tornado.ioloop.IOLoop The Tornado IOLoop that we're running within. session_data : SessionData Object storing parameters related to running a script uploaded_file_manager : UploadedFileManager The server's UploadedFileManager. message_enqueued_callback : Callable[[], None] After enqueuing a message, this callable notification will be invoked. local_sources_watcher: LocalSourcesWatcher The file watcher that lets the session know local files have changed. """ # Each AppSession has a unique string ID. self.id = str(uuid.uuid4()) self._ioloop = ioloop self._session_data = session_data self._uploaded_file_mgr = uploaded_file_manager self._message_enqueued_callback = message_enqueued_callback self._state = AppSessionState.APP_NOT_RUNNING # Need to remember the client state here because when a script reruns # due to the source code changing we need to pass in the previous client state. self._client_state = ClientState() self._local_sources_watcher = local_sources_watcher self._local_sources_watcher.register_file_change_callback( self._on_source_file_changed) self._stop_config_listener = config.on_config_parsed( self._on_source_file_changed, force_connect=True) # The script should rerun when the `secrets.toml` file has been changed. 
secrets._file_change_listener.connect(self._on_secrets_file_changed) self._run_on_save = config.get_option("server.runOnSave") # The ScriptRequestQueue is the means by which we communicate # with the active ScriptRunner. self._script_request_queue = ScriptRequestQueue() self._scriptrunner: Optional[ScriptRunner] = None # This needs to be lazily imported to avoid a dependency cycle. from streamlit.state.session_state import SessionState self._session_state = SessionState() LOGGER.debug("AppSession initialized (id=%s)", self.id) def flush_browser_queue(self) -> List[ForwardMsg]: """Clear the forward message queue and return the messages it contained. The Server calls this periodically to deliver new messages to the browser connected to this app. Returns ------- list[ForwardMsg] The messages that were removed from the queue and should be delivered to the browser. """ return self._session_data.flush_browser_queue() def shutdown(self) -> None: """Shut down the AppSession. It's an error to use a AppSession after it's been shut down. """ if self._state != AppSessionState.SHUTDOWN_REQUESTED: LOGGER.debug("Shutting down (id=%s)", self.id) # Clear any unused session files in upload file manager and media # file manager self._uploaded_file_mgr.remove_session_files(self.id) in_memory_file_manager.clear_session_files(self.id) in_memory_file_manager.del_expired_files() # Shut down the ScriptRunner, if one is active. # self._state must not be set to SHUTDOWN_REQUESTED until # after this is called. if self._scriptrunner is not None: self._enqueue_script_request(ScriptRequest.SHUTDOWN) self._state = AppSessionState.SHUTDOWN_REQUESTED self._local_sources_watcher.close() if self._stop_config_listener is not None: self._stop_config_listener() secrets._file_change_listener.disconnect( self._on_secrets_file_changed) def enqueue(self, msg: ForwardMsg) -> None: """Enqueue a new ForwardMsg to our browser queue. This can be called on both the main thread and a ScriptRunner run thread. 
Parameters ---------- msg : ForwardMsg The message to enqueue """ if not config.get_option("client.displayEnabled"): return # Avoid having two maybe_handle_execution_control_request running on # top of each other when tracer is installed. This leads to a lock # contention. if not config.get_option("runner.installTracer"): # If we have an active ScriptRunner, signal that it can handle an # execution control request. (Copy the scriptrunner reference to # avoid it being unset from underneath us, as this function can be # called outside the main thread.) scriptrunner = self._scriptrunner if scriptrunner is not None: scriptrunner.maybe_handle_execution_control_request() self._session_data.enqueue(msg) if self._message_enqueued_callback: self._message_enqueued_callback() def enqueue_exception(self, e: BaseException) -> None: """Enqueue an Exception message.""" # This does a few things: # 1) Clears the current app in the browser. # 2) Marks the current app as "stopped" in the browser. # 3) HACK: Resets any script params that may have been broken (e.g. the # command-line when rerunning with wrong argv[0]) self._on_scriptrunner_event( ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS) self._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STARTED) self._on_scriptrunner_event( ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS) msg = ForwardMsg() exception_utils.marshall(msg.delta.new_element.exception, e) self.enqueue(msg) def request_rerun(self, client_state: Optional[ClientState]) -> None: """Signal that we're interested in running the script. If the script is not already running, it will be started immediately. Otherwise, a rerun will be requested. Parameters ---------- client_state : streamlit.proto.ClientState_pb2.ClientState | None The ClientState protobuf to run the script with, or None to use previous client state. 
""" if client_state: rerun_data = RerunData(client_state.query_string, client_state.widget_states) else: rerun_data = RerunData() self._enqueue_script_request(ScriptRequest.RERUN, rerun_data) @property def session_state(self) -> "SessionState": return self._session_state def _on_source_file_changed(self) -> None: """One of our source files changed. Schedule a rerun if appropriate.""" if self._run_on_save: self.request_rerun(self._client_state) else: self._enqueue_file_change_message() def _on_secrets_file_changed(self, _) -> None: """Called when `secrets._file_change_listener` emits a Signal.""" # NOTE: At the time of writing, this function only calls `_on_source_file_changed`. # The reason behind creating this function instead of just passing `_on_source_file_changed` # to `connect` / `disconnect` directly is that every function that is passed to `connect` / `disconnect` # must have at least one argument for `sender` (in this case we don't really care about it, thus `_`), # and introducing an unnecessary argument to `_on_source_file_changed` just for this purpose sounded finicky. self._on_source_file_changed() def _clear_queue(self) -> None: self._session_data.clear() def _on_scriptrunner_event( self, event: ScriptRunnerEvent, exception: Optional[BaseException] = None, client_state: Optional[ClientState] = None, ) -> None: """Called when our ScriptRunner emits an event. This is *not* called on the main thread. Parameters ---------- event : ScriptRunnerEvent exception : BaseException | None An exception thrown during compilation. Set only for the SCRIPT_STOPPED_WITH_COMPILE_ERROR event. client_state : streamlit.proto.ClientState_pb2.ClientState | None The ScriptRunner's final ClientState. Set only for the SHUTDOWN event. 
""" LOGGER.debug("OnScriptRunnerEvent: %s", event) prev_state = self._state if event == ScriptRunnerEvent.SCRIPT_STARTED: if self._state != AppSessionState.SHUTDOWN_REQUESTED: self._state = AppSessionState.APP_IS_RUNNING self._clear_queue() self._enqueue_new_session_message() elif (event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR): if self._state != AppSessionState.SHUTDOWN_REQUESTED: self._state = AppSessionState.APP_NOT_RUNNING script_succeeded = event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS self._enqueue_script_finished_message( ForwardMsg.FINISHED_SUCCESSFULLY if script_succeeded else ForwardMsg.FINISHED_WITH_COMPILE_ERROR) if script_succeeded: # When a script completes successfully, we update our # LocalSourcesWatcher to account for any source code changes # that change which modules should be watched. (This is run on # the main thread, because LocalSourcesWatcher is not # thread safe.) self._ioloop.spawn_callback( self._local_sources_watcher.update_watched_modules) else: msg = ForwardMsg() exception_utils.marshall( msg.session_event.script_compilation_exception, exception) self.enqueue(msg) elif event == ScriptRunnerEvent.SHUTDOWN: # When ScriptRunner shuts down, update our local reference to it, # and check to see if we need to spawn a new one. (This is run on # the main thread.) assert ( client_state is not None), "client_state must be set for the SHUTDOWN event" if self._state == AppSessionState.SHUTDOWN_REQUESTED: # Only clear media files if the script is done running AND the # session is actually shutting down. in_memory_file_manager.clear_session_files(self.id) def on_shutdown(): # We assert above that this is non-null self._client_state = cast(ClientState, client_state) self._scriptrunner = None # Because a new ScriptEvent could have been enqueued while the # scriptrunner was shutting down, we check to see if we should # create a new one. 
(Otherwise, a newly-enqueued ScriptEvent # won't be processed until another event is enqueued.) self._maybe_create_scriptrunner() self._ioloop.spawn_callback(on_shutdown) # Send a message if our run state changed app_was_running = prev_state == AppSessionState.APP_IS_RUNNING app_is_running = self._state == AppSessionState.APP_IS_RUNNING if app_is_running != app_was_running: self._enqueue_session_state_changed_message() def _enqueue_session_state_changed_message(self) -> None: msg = ForwardMsg() msg.session_state_changed.run_on_save = self._run_on_save msg.session_state_changed.script_is_running = ( self._state == AppSessionState.APP_IS_RUNNING) self.enqueue(msg) def _enqueue_file_change_message(self) -> None: LOGGER.debug("Enqueuing script_changed message (id=%s)", self.id) msg = ForwardMsg() msg.session_event.script_changed_on_disk = True self.enqueue(msg) def _enqueue_new_session_message(self) -> None: msg = ForwardMsg() msg.new_session.script_run_id = _generate_scriptrun_id() msg.new_session.name = self._session_data.name msg.new_session.script_path = self._session_data.script_path _populate_config_msg(msg.new_session.config) _populate_theme_msg(msg.new_session.custom_theme) # Immutable session data. We send this every time a new session is # started, to avoid having to track whether the client has already # received it. It does not change from run to run; it's up to the # to perform one-time initialization only once. 
imsg = msg.new_session.initialize _populate_user_info_msg(imsg.user_info) imsg.environment_info.streamlit_version = __version__ imsg.environment_info.python_version = ".".join( map(str, sys.version_info)) imsg.session_state.run_on_save = self._run_on_save imsg.session_state.script_is_running = ( self._state == AppSessionState.APP_IS_RUNNING) imsg.command_line = self._session_data.command_line imsg.session_id = self.id self.enqueue(msg) def _enqueue_script_finished_message( self, status: "ForwardMsg.ScriptFinishedStatus.ValueType") -> None: """Enqueue a script_finished ForwardMsg.""" msg = ForwardMsg() msg.script_finished = status self.enqueue(msg) def handle_git_information_request(self) -> None: msg = ForwardMsg() try: from streamlit.git_util import GitRepo repo = GitRepo(self._session_data.script_path) repo_info = repo.get_repo_info() if repo_info is None: return repository_name, branch, module = repo_info msg.git_info_changed.repository = repository_name msg.git_info_changed.branch = branch msg.git_info_changed.module = module msg.git_info_changed.untracked_files[:] = repo.untracked_files msg.git_info_changed.uncommitted_files[:] = repo.uncommitted_files if repo.is_head_detached: msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED elif len(repo.ahead_commits) > 0: msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE else: msg.git_info_changed.state = GitInfo.GitStates.DEFAULT self.enqueue(msg) except Exception as e: # Users may never even install Git in the first place, so this # error requires no action. It can be useful for debugging. LOGGER.debug("Obtaining Git information produced an error", exc_info=e) def handle_rerun_script_request(self, client_state: Optional[ClientState] = None ) -> None: """Tell the ScriptRunner to re-run its script. Parameters ---------- client_state : streamlit.proto.ClientState_pb2.ClientState | None The ClientState protobuf to run the script with, or None to use previous client state. 
""" self.request_rerun(client_state) def handle_stop_script_request(self) -> None: """Tell the ScriptRunner to stop running its script.""" self._enqueue_script_request(ScriptRequest.STOP) def handle_clear_cache_request(self) -> None: """Clear this app's cache. Because this cache is global, it will be cleared for all users. """ legacy_caching.clear_cache() caching.memo.clear() caching.singleton.clear() self._session_state.clear_state() def handle_set_run_on_save_request(self, new_value: bool) -> None: """Change our run_on_save flag to the given value. The browser will be notified of the change. Parameters ---------- new_value : bool New run_on_save value """ self._run_on_save = new_value self._enqueue_session_state_changed_message() def _enqueue_script_request(self, request: ScriptRequest, data: Any = None) -> None: """Enqueue a ScriptEvent into our ScriptEventQueue. If a script thread is not already running, one will be created to handle the event. Parameters ---------- request : ScriptRequest The type of request. data : Any Data associated with the request, if any. """ if self._state == AppSessionState.SHUTDOWN_REQUESTED: LOGGER.warning("Discarding %s request after shutdown" % request) return self._script_request_queue.enqueue(request, data) self._maybe_create_scriptrunner() def _maybe_create_scriptrunner(self) -> None: """Create a new ScriptRunner if we have unprocessed script requests. This is called every time a ScriptRequest is enqueued, and also after a ScriptRunner shuts down, in case new requests were enqueued during its termination. This function should only be called on the main thread. 
""" if (self._state == AppSessionState.SHUTDOWN_REQUESTED or self._scriptrunner is not None or not self._script_request_queue.has_request): return # Create the ScriptRunner, attach event handlers, and start it self._scriptrunner = ScriptRunner( session_id=self.id, session_data=self._session_data, enqueue_forward_msg=self.enqueue, client_state=self._client_state, request_queue=self._script_request_queue, session_state=self._session_state, uploaded_file_mgr=self._uploaded_file_mgr, ) self._scriptrunner.on_event.connect(self._on_scriptrunner_event) self._scriptrunner.start()
class TestScriptRunner(ScriptRunner):
    """Subclasses ScriptRunner to provide some testing features.

    - Forward messages emitted by the script are collected in
      ``self.forward_msg_queue`` and can be inspected via ``deltas()``,
      ``elements()``, ``text_deltas()`` and ``get_widget_id()``.
    - Script requests are fed through ``self.script_request_queue`` using the
      ``enqueue_*`` helpers.
    - Exceptions escaping the script thread are collected in
      ``self.script_thread_exceptions`` rather than propagating.
    - Every emitted ScriptRunnerEvent is recorded in ``self.events``, with the
      keyword data that accompanied it at the same index in
      ``self.event_data``.
    """

    def __init__(self, script_name: str):
        """Initializes the ScriptRunner for the given script_name.

        Parameters
        ----------
        script_name : str
            File name of a script inside the ``test_data`` directory that
            sits next to this module.
        """
        # DeltaGenerator deltas will be enqueued into self.forward_msg_queue.
        self.forward_msg_queue = ForwardMsgQueue()
        self.script_request_queue = ScriptRequestQueue()

        main_script_path = os.path.join(
            os.path.dirname(__file__), "test_data", script_name
        )

        super().__init__(
            session_id="test session id",
            session_data=SessionData(main_script_path, "test command line"),
            enqueue_forward_msg=self.forward_msg_queue.enqueue,
            client_state=ClientState(),
            session_state=SessionState(),
            request_queue=self.script_request_queue,
            uploaded_file_mgr=UploadedFileManager(),
        )

        # Accumulates uncaught exceptions thrown by our run thread.
        self.script_thread_exceptions: List[BaseException] = []

        # Accumulates all ScriptRunnerEvents emitted by us. event_data[i]
        # holds the keyword arguments that accompanied events[i].
        self.events: List[ScriptRunnerEvent] = []
        self.event_data: List[Any] = []

        def record_event(
            sender: Optional[ScriptRunner], event: ScriptRunnerEvent, **kwargs
        ) -> None:
            # Assert that we're not getting unexpected `sender` params
            # from ScriptRunner.on_event
            assert (
                sender is None or sender == self
            ), "Unexpected ScriptRunnerEvent sender!"

            self.events.append(event)
            self.event_data.append(kwargs)

        # record_event is a local closure with no other strong reference.
        # on_event appears to be a blinker-style Signal, and weak=False keeps
        # the receiver alive after __init__ returns.
        self.on_event.connect(record_event, weak=False)

    def enqueue_rerun(self, argv=None, widget_states=None, query_string=""):
        """Enqueue a RERUN request carrying the given widget states and
        query string.

        ``argv`` is accepted for backward compatibility only and is ignored.
        """
        self.script_request_queue.enqueue(
            ScriptRequest.RERUN,
            RerunData(widget_states=widget_states, query_string=query_string),
        )

    def enqueue_stop(self):
        """Enqueue a STOP request."""
        self.script_request_queue.enqueue(ScriptRequest.STOP)

    def enqueue_shutdown(self):
        """Enqueue a SHUTDOWN request."""
        self.script_request_queue.enqueue(ScriptRequest.SHUTDOWN)

    def _run_script_thread(self):
        """Run the base script thread, recording any uncaught exception.

        BaseException is caught on purpose so that SystemExit,
        KeyboardInterrupt and similar are also captured for tests to inspect,
        instead of escaping the thread.
        """
        try:
            super()._run_script_thread()
        except BaseException as e:
            self.script_thread_exceptions.append(e)

    def _run_script(self, rerun_data):
        """Clear previously captured messages, then run the script."""
        # Dropping earlier output lets each run's messages be inspected on
        # their own.
        self.forward_msg_queue.clear()
        super()._run_script(rerun_data)

    def join(self):
        """Joins the run thread, if it was started"""
        if self._script_thread is not None:
            self._script_thread.join()

    def clear_deltas(self):
        """Clear all messages (not only deltas) from our ForwardMsgQueue"""
        self.forward_msg_queue.clear()

    def deltas(self) -> List[Delta]:
        """Return the delta messages in our ForwardMsgQueue"""
        # Reads ForwardMsgQueue's private _queue directly; no public
        # accessor is used here.
        return [
            msg.delta for msg in self.forward_msg_queue._queue if msg.HasField("delta")
        ]

    def elements(self) -> List[Element]:
        """Return the delta.new_element messages in our ForwardMsgQueue."""
        return [delta.new_element for delta in self.deltas()]

    def text_deltas(self) -> List[str]:
        """Return the string contents of text deltas in our ForwardMsgQueue"""
        return [
            element.text.body
            for element in self.elements()
            if element.WhichOneof("type") == "text"
        ]

    def get_widget_id(self, widget_type, label):
        """Returns the id of the widget with the specified type and label.

        Parameters
        ----------
        widget_type : str
            Name of the Element ``type`` oneof field, e.g. ``"checkbox"``.
        label : str
            The widget's label to match.

        Returns
        -------
        The id of the first matching widget, or None if there is no match.
        """
        for element in self.elements():
            # Only consider elements that actually are this widget type.
            # Reading an unset protobuf field returns an empty default
            # message, which could otherwise false-match (e.g. label "") or
            # lead to returning the id of a widget that doesn't exist.
            if element.WhichOneof("type") != widget_type:
                continue
            widget = getattr(element, widget_type)
            if getattr(widget, "label", None) == label:
                return widget.id
        return None