def test_coalesce_widget_states(self):
    """coalesce_widget_states merges trigger values but keeps newest non-triggers."""
    previous = WidgetStates()
    for widget_id, field_name, value in (
        ("old_set_trigger", "trigger_value", True),
        ("old_unset_trigger", "trigger_value", False),
        ("missing_in_new", "int_value", 123),
        ("shape_changing_trigger", "trigger_value", True),
    ):
        setattr(_create_widget(widget_id, previous), field_name, value)

    latest = WidgetStates()
    for widget_id, field_name, value in (
        ("old_set_trigger", "trigger_value", False),
        ("new_set_trigger", "trigger_value", True),
        ("added_in_new", "int_value", 456),
        ("shape_changing_trigger", "int_value", 3),
    ):
        setattr(_create_widget(widget_id, latest), field_name, value)

    manager = WidgetStateManager()
    manager.set_state(coalesce_widget_states(previous, latest))

    for absent_id in ("old_unset_trigger", "missing_in_new"):
        self.assertIsNone(manager.get_widget_value(absent_id))

    for widget_id, expected in (
        ("old_set_trigger", True),
        ("new_set_trigger", True),
        ("added_in_new", 456),
    ):
        self.assertEqual(expected, manager.get_widget_value(widget_id))

    # A widget that used to be a trigger but no longer is must *not* be
    # coalesced: only its new (non-trigger) value survives.
    self.assertEqual(3, manager.get_widget_value("shape_changing_trigger"))
def test_reset_triggers(self):
    """reset_triggers() clears trigger values and leaves other widgets intact."""
    states = WidgetStates()
    widgets = WidgetStateManager()

    _create_widget("trigger", states).trigger_value = True
    _create_widget("int", states).int_value = 123

    widgets.set_state(states)
    self.assertEqual(True, widgets.get_widget_value("trigger"))
    self.assertEqual(123, widgets.get_widget_value("int"))

    widgets.reset_triggers()
    # assertIsNone (as used in test_coalesce_widget_states) checks identity
    # with None rather than relying on an == comparison.
    self.assertIsNone(widgets.get_widget_value("trigger"))
    self.assertEqual(123, widgets.get_widget_value("int"))
def test_values(self):
    """Each typed widget value round-trips through WidgetStateManager."""
    states = WidgetStates()
    fixtures = (
        ("trigger", "trigger_value", True),
        ("bool", "bool_value", True),
        ("float", "double_value", 0.5),
        ("int", "int_value", 123),
        ("string", "string_value", "howdy!"),
    )
    for widget_id, field_name, value in fixtures:
        setattr(_create_widget(widget_id, states), field_name, value)

    manager = WidgetStateManager()
    manager.set_state(states)

    self.assertEqual(True, manager.get_widget_value("trigger"))
    self.assertEqual(True, manager.get_widget_value("bool"))
    # Floats are compared approximately.
    self.assertAlmostEqual(0.5, manager.get_widget_value("float"))
    self.assertEqual(123, manager.get_widget_value("int"))
    self.assertEqual("howdy!", manager.get_widget_value("string"))
def test_handle_save_request(self, _1):
    """Test that handle_save_request serializes files correctly."""
    # Create a ReportSession with some mocked bits
    rs = ReportSession(self.io_loop, "mock_report.py", "", UploadedFileManager())
    rs._report.report_id = "TestReportID"

    orig_ctx = get_report_ctx()
    ctx = ReportContext(
        "TestSessionID",
        rs._report.enqueue,
        "",
        WidgetStateManager(),
        UploadedFileManager(),
    )
    add_report_ctx(ctx=ctx)

    # Restore the original ReportContext even if an assertion below fails,
    # so this test's mock context can't leak into subsequent tests.
    try:
        rs._scriptrunner = MagicMock()

        storage = MockStorage()
        rs._storage = storage

        # Send two deltas: empty and markdown
        st.empty()
        st.markdown("Text!")

        yield rs.handle_save_request(_create_mock_websocket())

        # Check the order of the received files. Manifest should be last.
        self.assertEqual(3, len(storage.files))
        self.assertEqual("reports/TestReportID/0.pb", storage.get_filename(0))
        self.assertEqual("reports/TestReportID/1.pb", storage.get_filename(1))
        self.assertEqual(
            "reports/TestReportID/manifest.pb", storage.get_filename(2)
        )

        # Check the manifest
        manifest = storage.get_message(2, StaticManifest)
        self.assertEqual("mock_report", manifest.name)
        self.assertEqual(2, manifest.num_messages)
        self.assertEqual(StaticManifest.DONE, manifest.server_status)

        # Check that the deltas we sent match messages in storage
        sent_messages = rs._report._master_queue._queue
        received_messages = [
            storage.get_message(0, ForwardMsg),
            storage.get_message(1, ForwardMsg),
        ]

        self.assertEqual(sent_messages, received_messages)
    finally:
        add_report_ctx(ctx=orig_ctx)
def setUp(self, override_root=True):
    """Create a fresh ReportQueue and, optionally, install a test ReportContext.

    When ``override_root`` is true, the current thread's ReportContext is
    saved in ``self.orig_report_ctx`` and replaced by one that enqueues into
    ``self.report_queue``.
    """
    self.report_queue = ReportQueue()
    self.override_root = override_root
    self.orig_report_ctx = None

    if not self.override_root:
        return

    self.orig_report_ctx = get_report_ctx()
    test_ctx = ReportContext(
        session_id="test session id",
        enqueue=self.report_queue.enqueue,
        query_string="",
        widgets=WidgetStateManager(),
        uploaded_file_mgr=UploadedFileManager(),
    )
    add_report_ctx(threading.current_thread(), test_ctx)
def test_set_page_config_immutable(self):
    """A second st.set_page_config within the same run is rejected."""

    def discard(msg):
        return None

    report_ctx = ReportContext(
        session_id="TestSessionID",
        enqueue=discard,
        query_string="",
        widgets=WidgetStateManager(),
        uploaded_file_mgr=UploadedFileManager(),
    )

    page_config = ForwardMsg()
    page_config.page_config_changed.title = "foo"

    report_ctx.enqueue(page_config)
    with self.assertRaises(StreamlitAPIException):
        report_ctx.enqueue(page_config)
def test_set_page_config_reset(self):
    """After ReportContext.reset() (i.e. a rerun), st.set_page_config works again."""

    def discard(msg):
        return None

    report_ctx = ReportContext(
        session_id="TestSessionID",
        enqueue=discard,
        query_string="",
        widgets=WidgetStateManager(),
        uploaded_file_mgr=UploadedFileManager(),
    )

    page_config = ForwardMsg()
    page_config.page_config_changed.title = "foo"

    report_ctx.enqueue(page_config)
    report_ctx.reset()

    try:
        report_ctx.enqueue(page_config)
    except StreamlitAPIException:
        self.fail("set_page_config should have succeeded after reset!")
def test_set_page_config_first(self):
    """st.set_page_config is rejected once any other element was enqueued."""

    def discard(msg):
        return None

    report_ctx = ReportContext(
        session_id="TestSessionID",
        enqueue=discard,
        query_string="",
        widgets=WidgetStateManager(),
        uploaded_file_mgr=UploadedFileManager(),
    )

    element = ForwardMsg()
    element.delta.new_element.markdown.body = "foo"

    page_config = ForwardMsg()
    page_config.page_config_changed.title = "foo"

    report_ctx.enqueue(element)
    with self.assertRaises(StreamlitAPIException):
        report_ctx.enqueue(page_config)
def __init__(
    self,
    session_id,
    report,
    enqueue_forward_msg,
    client_state,
    request_queue,
    uploaded_file_mgr=None,
):
    """Initialize the ScriptRunner.

    (The ScriptRunner won't start executing until start() is called.)

    Parameters
    ----------
    session_id : str
        The ReportSession's id.

    report : Report
        The ReportSession's report.

    enqueue_forward_msg : Callable
        Callback handed to the script thread as its ``enqueue`` function
        (see start()); presumably used to deliver the ForwardMsgs produced
        by the script -- confirm against the caller.

    client_state : streamlit.proto.ClientState_pb2.ClientState
        The current state from the client (widgets and query params).

    request_queue : ScriptRequestQueue
        The queue that the ReportSession is publishing ScriptRequests to.
        ScriptRunner will continue running until the queue is empty,
        and then shut down.

    uploaded_file_mgr : UploadedFileManager
        The File manager to store the data uploaded by the file_uploader widget.

    """
    self._session_id = session_id
    self._report = report
    self._enqueue_forward_msg = enqueue_forward_msg
    self._request_queue = request_queue
    self._uploaded_file_mgr = uploaded_file_mgr

    self._client_state = client_state
    self._widgets = WidgetStateManager()
    self._widgets.set_state(client_state.widget_states)

    # The SHUTDOWN event is sent with a ``client_state`` kwarg (not
    # ``widget_states``), so the signal's documentation names that kwarg.
    self.on_event = Signal(
        doc="""Emitted when a ScriptRunnerEvent occurs.

        This signal is *not* emitted on the same thread that the
        ScriptRunner was created on.

        Parameters
        ----------
        event : ScriptRunnerEvent

        exception : BaseException | None
            Our compile error. Set only for the
            SCRIPT_STOPPED_WITH_COMPILE_ERROR event.

        client_state : streamlit.proto.ClientState_pb2.ClientState | None
            The ScriptRunner's final ClientState (query string and widget
            states). Set only for the SHUTDOWN event.
        """
    )

    # Set to true when we process a SHUTDOWN request
    self._shutdown_requested = False

    # Set to true while we're executing. Used by
    # maybe_handle_execution_control_request.
    self._execing = False

    # This is initialized in start()
    self._script_thread = None
class ScriptRunner(object):
    """Runs a report's script on a dedicated thread, driven by a ScriptRequestQueue."""

    def __init__(
        self,
        session_id,
        report,
        enqueue_forward_msg,
        client_state,
        request_queue,
        uploaded_file_mgr=None,
    ):
        """Initialize the ScriptRunner.

        (The ScriptRunner won't start executing until start() is called.)

        Parameters
        ----------
        session_id : str
            The ReportSession's id.

        report : Report
            The ReportSession's report.

        enqueue_forward_msg : Callable
            Callback handed to the script thread as its ``enqueue`` function
            (see start()); presumably used to deliver the ForwardMsgs
            produced by the script -- confirm against the caller.

        client_state : streamlit.proto.ClientState_pb2.ClientState
            The current state from the client (widgets and query params).

        request_queue : ScriptRequestQueue
            The queue that the ReportSession is publishing ScriptRequests to.
            ScriptRunner will continue running until the queue is empty,
            and then shut down.

        uploaded_file_mgr : UploadedFileManager
            The File manager to store the data uploaded by the file_uploader widget.

        """
        self._session_id = session_id
        self._report = report
        self._enqueue_forward_msg = enqueue_forward_msg
        self._request_queue = request_queue
        self._uploaded_file_mgr = uploaded_file_mgr

        self._client_state = client_state
        self._widgets = WidgetStateManager()
        self._widgets.set_state(client_state.widget_states)

        # The SHUTDOWN event is sent with a ``client_state`` kwarg (see
        # _process_request_queue), so the documentation names that kwarg.
        self.on_event = Signal(
            doc="""Emitted when a ScriptRunnerEvent occurs.

            This signal is *not* emitted on the same thread that the
            ScriptRunner was created on.

            Parameters
            ----------
            event : ScriptRunnerEvent

            exception : BaseException | None
                Our compile error. Set only for the
                SCRIPT_STOPPED_WITH_COMPILE_ERROR event.

            client_state : streamlit.proto.ClientState_pb2.ClientState | None
                The ScriptRunner's final ClientState (query string and widget
                states). Set only for the SHUTDOWN event.
            """
        )

        # Set to true when we process a SHUTDOWN request
        self._shutdown_requested = False

        # Set to true while we're executing. Used by
        # maybe_handle_execution_control_request.
        self._execing = False

        # This is initialized in start()
        self._script_thread = None

    def __repr__(self) -> str:
        return util.repr_(self)

    def start(self):
        """Start a new thread to process the ScriptRequestQueue.

        This must be called only once.

        Raises
        ------
        RuntimeError
            If the ScriptRunner was already started.

        """
        if self._script_thread is not None:
            # RuntimeError is a subclass of Exception, so callers that caught
            # the previous bare Exception still work.
            raise RuntimeError("ScriptRunner was already started")

        self._script_thread = ReportThread(
            session_id=self._session_id,
            enqueue=self._enqueue_forward_msg,
            query_string=self._client_state.query_string,
            widgets=self._widgets,
            uploaded_file_mgr=self._uploaded_file_mgr,
            target=self._process_request_queue,
            name="ScriptRunner.scriptThread",
        )
        self._script_thread.start()

    def _process_request_queue(self):
        """Process the ScriptRequestQueue, then exit.

        This is run in a separate thread.

        """
        LOGGER.debug("Beginning script thread")

        while not self._shutdown_requested and self._request_queue.has_request:
            request, data = self._request_queue.dequeue()
            if request == ScriptRequest.STOP:
                LOGGER.debug("Ignoring STOP request while not running")
            elif request == ScriptRequest.SHUTDOWN:
                LOGGER.debug("Shutting down")
                self._shutdown_requested = True
            elif request == ScriptRequest.RERUN:
                self._run_script(data)
            else:
                raise RuntimeError("Unrecognized ScriptRequest: %s" % request)

        # Send a SHUTDOWN event before exiting. This includes the widget values
        # as they existed after our last successful script run, which the
        # ReportSession will pass on to the next ScriptRunner that gets
        # created.
        client_state = ClientState()
        client_state.query_string = self._client_state.query_string
        self._widgets.marshall(client_state)
        self.on_event.send(ScriptRunnerEvent.SHUTDOWN, client_state=client_state)

    def _is_in_script_thread(self):
        """True if the calling function is running in the script thread"""
        return self._script_thread == threading.current_thread()

    def maybe_handle_execution_control_request(self):
        """Act on a pending STOP/SHUTDOWN/RERUN request from inside exec().

        Does nothing unless called on the script thread while the script is
        executing. Otherwise it dequeues at most one request and converts it
        into a StopException or RerunException.
        """
        if not self._is_in_script_thread():
            # We can only handle execution_control_request if we're on the
            # script execution thread. However, it's possible for deltas to
            # be enqueued (and, therefore, for this function to be called)
            # in separate threads, so we check for that here.
            return

        if not self._execing:
            # If the _execing flag is not set, we're not actually inside
            # an exec() call. This happens when our script exec() completes,
            # we change our state to STOPPED, and a statechange-listener
            # enqueues a new ForwardEvent
            return

        # Pop the next request from our queue.
        request, data = self._request_queue.dequeue()
        if request is None:
            return

        LOGGER.debug("Received ScriptRequest: %s", request)
        if request == ScriptRequest.STOP:
            raise StopException()
        elif request == ScriptRequest.SHUTDOWN:
            self._shutdown_requested = True
            raise StopException()
        elif request == ScriptRequest.RERUN:
            raise RerunException(data)
        else:
            raise RuntimeError("Unrecognized ScriptRequest: %s" % request)

    def _install_tracer(self):
        """Install function that runs before each line of the script."""

        def trace_calls(frame, event, arg):
            self.maybe_handle_execution_control_request()
            return trace_calls

        # Python interpreters are not required to implement sys.settrace.
        if hasattr(sys, "settrace"):
            sys.settrace(trace_calls)

    @contextmanager
    def _set_execing_flag(self):
        """A context for setting the ScriptRunner._execing flag.

        Used by maybe_handle_execution_control_request to ensure that
        we only handle requests while we're inside an exec() call

        Raises
        ------
        RuntimeError
            If the flag is already set (nested use).
        """
        if self._execing:
            raise RuntimeError("Nested set_execing_flag call")
        self._execing = True
        try:
            yield
        finally:
            self._execing = False

    def _run_script(self, rerun_data):
        """Run our script, re-running it while runs are interrupted by reruns.

        Parameters
        ----------
        rerun_data: RerunData
            The RerunData to use.

        """
        # Iterate instead of recursing. Previously each run interrupted by a
        # RerunException added a stack frame, so a long chain of back-to-back
        # reruns could eventually exceed the interpreter's recursion limit.
        # The first run always executes, just like the recursive version.
        while True:
            rerun_data = self._run_script_once(rerun_data)
            if rerun_data is None:
                break

    def _run_script_once(self, rerun_data):
        """Compile and execute the script a single time.

        Returns the RerunData carried by a RerunException that interrupted
        this run, or None otherwise (including on a compile error).
        """
        assert self._is_in_script_thread()

        LOGGER.debug("Running script %s", rerun_data)

        # Reset DeltaGenerators, widgets, media files.
        media_file_manager.clear_session_files()

        ctx = get_report_ctx()
        if ctx is None:
            # This should never be possible on the script_runner thread.
            raise RuntimeError(
                "ScriptRunner thread has a null ReportContext. Something has gone very wrong!"
            )

        ctx.reset(query_string=rerun_data.query_string)

        self.on_event.send(ScriptRunnerEvent.SCRIPT_STARTED)

        # Compile the script. Any errors thrown here will be surfaced
        # to the user via a modal dialog in the frontend, and won't result
        # in their previous report disappearing.

        try:
            with source_util.open_python_file(self._report.script_path) as f:
                filebody = f.read()

            if config.get_option("runner.magicEnabled"):
                filebody = magic.add_magic(filebody, self._report.script_path)

            code = compile(
                filebody,
                # Pass in the file path so it can show up in exceptions.
                self._report.script_path,
                # We're compiling entire blocks of Python, so we need "exec"
                # mode (as opposed to "eval" or "single").
                mode="exec",
                # Don't inherit any flags or "future" statements.
                flags=0,
                dont_inherit=1,
                # Use the default optimization options.
                optimize=-1,
            )

        except BaseException as e:
            # We got a compile error. Send an error event and bail immediately.
            # Lazy %-args: only format the message if debug logging is on.
            LOGGER.debug("Fatal script error: %s", e)
            self.on_event.send(
                ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR, exception=e
            )
            return None

        # If we get here, we've successfully compiled our script. The next step
        # is to run it. Errors thrown during execution will be shown to the
        # user as ExceptionElements.

        # Update the Widget object with the new widget_states.
        # (The ReportContext has a reference to this object, so we just update it in-place)
        if rerun_data.widget_states is not None:
            self._widgets.set_state(rerun_data.widget_states)

        if config.get_option("runner.installTracer"):
            self._install_tracer()

        # This will be set to a RerunData instance if our execution
        # is interrupted by a RerunException.
        rerun_with_data = None

        try:
            # Create fake module. This gives us a name global namespace to
            # execute the code in.
            module = _new_module("__main__")

            # Install the fake module as the __main__ module. This allows
            # the pickle module to work inside the user's code, since it now
            # can know the module where the pickled objects stem from.
            # IMPORTANT: This means we can't use "if __name__ == '__main__'" in
            # our code, as it will point to the wrong module!!!
            sys.modules["__main__"] = module

            # Add special variables to the module's globals dict.
            # Note: The following is a requirement for the CodeHasher to
            # work correctly. The CodeHasher is scoped to
            # files contained in the directory of __main__.__file__, which we
            # assume is the main script directory.
            module.__dict__["__file__"] = self._report.script_path

            with modified_sys_path(self._report), self._set_execing_flag():
                exec(code, module.__dict__)

        except RerunException as e:
            rerun_with_data = e.rerun_data

        except StopException:
            pass

        except BaseException as e:
            handle_uncaught_app_exception(e)

        finally:
            self._on_script_finished(ctx)

        # Use _log_if_error() to make sure we never ever ever stop running the
        # script without meaning to.
        _log_if_error(_clean_problem_modules)

        return rerun_with_data

    def _on_script_finished(self, ctx: "ReportContext") -> None:
        """Called when our script finishes executing, even if it finished
        early with an exception. We perform post-run cleanup here.
        """
        self._widgets.reset_triggers()
        self._widgets.cull_nonexistent(ctx.widget_ids_this_run.items())
        # Signal that the script has finished. (We use SCRIPT_STOPPED_WITH_SUCCESS
        # even if we were stopped with an exception.)
        self.on_event.send(ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS)
        # Delete expired files now that the script has run and files in use
        # are marked as active.
        media_file_manager.del_expired_files()
def test_rerun_data_coalescing(self):
    """Test that multiple RERUN requests get coalesced with expected values.

    (This is similar to widgets_test.test_coalesce_widget_states - it's
    testing the same thing, but through the ScriptRequestQueue interface.)
    """
    queue = ScriptRequestQueue()

    states = WidgetStates()
    _create_widget("trigger", states).trigger_value = True
    _create_widget("int", states).int_value = 123
    queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=states))

    states = WidgetStates()
    _create_widget("trigger", states).trigger_value = False
    _create_widget("int", states).int_value = 456
    queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=states))

    event, data = queue.dequeue()
    self.assertEqual(event, ScriptRequest.RERUN)

    widgets = WidgetStateManager()
    widgets.set_state(data.widget_states)

    # Coalesced triggers should be True if either the old or
    # new value was True
    self.assertEqual(True, widgets.get_widget_value("trigger"))

    # Other widgets should have their newest value
    self.assertEqual(456, widgets.get_widget_value("int"))

    # We should have no more events
    self.assertEqual((None, None), queue.dequeue(), "Expected empty event queue")

    # Test that we can coalesce if previous widget state is None
    queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=None))
    queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=None))
    states = WidgetStates()
    _create_widget("int", states).int_value = 789
    queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=states))

    event, data = queue.dequeue()
    widgets = WidgetStateManager()
    widgets.set_state(data.widget_states)

    self.assertEqual(event, ScriptRequest.RERUN)
    self.assertEqual(789, widgets.get_widget_value("int"))

    # We should have no more events
    self.assertEqual((None, None), queue.dequeue(), "Expected empty event queue")

    # Test that we can coalesce if our *new* widget state is None
    states = WidgetStates()
    _create_widget("int", states).int_value = 101112
    queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=states))

    queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=None))

    event, data = queue.dequeue()
    widgets = WidgetStateManager()
    widgets.set_state(data.widget_states)

    self.assertEqual(event, ScriptRequest.RERUN)
    self.assertEqual(101112, widgets.get_widget_value("int"))

    # We should have no more events
    self.assertEqual((None, None), queue.dequeue(), "Expected empty event queue")