def __init__(self, ioloop, script_path, command_line):
    """Construct the (not yet started) server singleton.

    Only one Server may be constructed per process; the instance is
    registered as ``Server._singleton`` before any other setup runs.

    Parameters
    ----------
    ioloop : tornado.ioloop.IOLoop
        The IOLoop, stored on the instance as ``self._ioloop``.
    script_path : str
        Path of the user's script, stored as ``self._script_path``.
    command_line : str
        The user's command line, stored as ``self._command_line``.

    Raises
    ------
    RuntimeError
        If a Server singleton has already been registered.
    """
    if Server._singleton is not None:
        raise RuntimeError(
            "Server already initialized. Use .get_current() instead")
    Server._singleton = self

    _set_tornado_log_levels()

    self._ioloop, self._script_path, self._command_line = (
        ioloop,
        script_path,
        command_line,
    )

    # Keyed by ReportSession.id; values are SessionInfo records.
    self._session_info_by_id = {}

    self._must_stop = threading.Event()

    # _set_state() reads the previous state for its log line, so the
    # attribute has to exist before the first transition.
    self._state = None
    self._set_state(State.INITIAL)

    self._message_cache = ForwardMsgCache()

    file_mgr = UploadedFileManager()
    self._uploaded_file_mgr = file_mgr
    file_mgr.on_files_added.connect(self._on_file_uploaded)

    self._report = None  # type: Optional[Report]
def test_enqueue_with_tracer(self, _1, _2, patched_config):
    """Make sure there is no lock contention when tracer is on.

    When the tracer is set up, maybe_handle_execution_control_request should
    run only once. A past bug called it twice (once from the tracer and once
    from enqueue), which caused lock contention.
    """
    fake_options = {
        # Disabled only so that no file watcher is started needlessly.
        "server.runOnSave": False,
        "client.displayEnabled": True,
        "runner.installTracer": True,
    }

    def get_option(name):
        if name not in fake_options:
            raise RuntimeError("Unexpected argument to get_option: %s" % name)
        return fake_options[name]

    patched_config.get_option.side_effect = get_option

    session = ReportSession(False, None, "", "", UploadedFileManager())
    runner_mock = MagicMock()
    session._scriptrunner = runner_mock

    session.enqueue({"dontcare": 123})

    control_handler = runner_mock.maybe_handle_execution_control_request
    # Outside of tests the tracer would call this exactly once. Here no
    # tracer is actually installed (Report is mocked), so the correct
    # behavior is zero calls: enqueue must skip it while installTracer is
    # on. A call here most likely means enqueue regressed.
    control_handler.assert_not_called()
def test_enqueue_without_tracer(self, _1, _2, patched_config):
    """Make sure we try to handle execution control requests.

    With the tracer disabled, enqueue itself is responsible for giving the
    ScriptRunner a chance to handle control requests.
    """
    fake_options = {
        # Disabled only so that no file watcher is started needlessly.
        "server.runOnSave": False,
        "client.displayEnabled": True,
        "runner.installTracer": False,
    }

    def get_option(name):
        if name not in fake_options:
            raise RuntimeError("Unexpected argument to get_option: %s" % name)
        return fake_options[name]

    patched_config.get_option.side_effect = get_option

    session = ReportSession(False, None, "", "", UploadedFileManager())
    runner_mock = MagicMock()
    runner_mock._install_tracer = ScriptRunner._install_tracer
    session._scriptrunner = runner_mock

    session.enqueue({"dontcare": 123})

    # enqueue() should have invoked the handler exactly once.
    runner_mock.maybe_handle_execution_control_request.assert_called_once()
def __init__(self, ioloop, script_path, command_line):
    """Initialize the ReportSession.

    Parameters
    ----------
    ioloop : tornado.ioloop.IOLoop
        The Tornado IOLoop that we're running within.
    script_path : str
        Path of the Python file from which this report is generated.
    command_line : str
        Command line as input by the user.
    """
    # Hand out sequential IDs from a class-level counter.
    assigned_id = ReportSession._next_id
    ReportSession._next_id = assigned_id + 1
    self.id = assigned_id

    self._ioloop = ioloop
    self._report = Report(script_path, command_line)
    self._state = ReportSessionState.REPORT_NOT_RUNNING
    self._uploaded_file_mgr = UploadedFileManager()

    # Both root DeltaGenerators push their output through this session.
    enqueue = self.enqueue
    self._main_dg = DeltaGenerator(enqueue=enqueue, container=BlockPath.MAIN)
    self._sidebar_dg = DeltaGenerator(
        enqueue=enqueue, container=BlockPath.SIDEBAR)

    self._widget_states = WidgetStates()
    self._local_sources_watcher = LocalSourcesWatcher(
        self._report, self._on_source_file_changed)
    self._sent_initialize_message = False
    self._storage = None
    self._maybe_reuse_previous_run = False
    self._run_on_save = config.get_option("server.runOnSave")

    # Requests for the active ScriptRunner are passed through this queue;
    # the runner itself is created lazily, so none exists yet.
    self._script_request_queue = ScriptRequestQueue()
    self._scriptrunner = None

    LOGGER.debug("ReportSession initialized (id=%s)", self.id)
def setUp(self, override_root=True):
    """Create a fresh ReportQueue and optionally install a test ReportContext.

    Parameters
    ----------
    override_root : bool
        When True, the current thread gets a new ReportContext whose
        ``enqueue`` feeds ``self.report_queue``. The previously active
        context is saved in ``self.orig_report_ctx`` (presumably restored
        by a tearDown that is not visible here).
    """
    self.report_queue = ReportQueue()
    self.override_root = override_root
    self.orig_report_ctx = None

    if not override_root:
        return

    self.orig_report_ctx = get_report_ctx()
    test_ctx = ReportContext(
        enqueue=self.report_queue.enqueue,
        widgets=Widgets(),
        widget_ids_this_run=_WidgetIDSet(),
        uploaded_file_mgr=UploadedFileManager(),
    )
    add_report_ctx(threading.current_thread(), test_ctx)
def test_handle_save_request(self, _1):
    """Test that handle_save_request serializes files correctly.

    Two deltas are generated; storage should receive one .pb file per delta
    followed by a manifest, and the stored deltas must match what was
    enqueued.
    """
    # Create a ReportSession with some mocked bits
    rs = ReportSession(False, self.io_loop, "mock_report.py", "",
                       UploadedFileManager())
    rs._report.report_id = "TestReportID"

    orig_ctx = get_report_ctx()
    ctx = ReportContext("TestSessionID", rs._report.enqueue, None, None, None)
    add_report_ctx(ctx=ctx)

    # Always restore the original context, even when an assertion fails;
    # otherwise the fake context would leak into later tests on this thread.
    try:
        rs._scriptrunner = MagicMock()

        storage = MockStorage()
        rs._storage = storage

        # Send two deltas: empty and markdown
        st.empty()
        st.markdown("Text!")

        yield rs.handle_save_request(_create_mock_websocket())

        # Check the order of the received files. Manifest should be last.
        self.assertEqual(3, len(storage.files))
        self.assertEqual("reports/TestReportID/0.pb",
                         storage.get_filename(0))
        self.assertEqual("reports/TestReportID/1.pb",
                         storage.get_filename(1))
        self.assertEqual("reports/TestReportID/manifest.pb",
                         storage.get_filename(2))

        # Check the manifest
        manifest = storage.get_message(2, StaticManifest)
        self.assertEqual("mock_report", manifest.name)
        self.assertEqual(2, manifest.num_messages)
        self.assertEqual(StaticManifest.DONE, manifest.server_status)

        # Check that the deltas we sent match messages in storage
        sent_messages = rs._report._master_queue._queue
        received_messages = [
            storage.get_message(0, ForwardMsg),
            storage.get_message(1, ForwardMsg),
        ]
        self.assertEqual(sent_messages, received_messages)
    finally:
        add_report_ctx(ctx=orig_ctx)
def setUp(self, override_root=True):
    """Create a fresh ReportQueue and optionally attach a test ReportContext.

    Parameters
    ----------
    override_root : bool
        When True, a ReportContext holding new main and sidebar
        DeltaGenerators is stored on the current thread under
        REPORT_CONTEXT_ATTR_NAME.
    """
    self.report_queue = ReportQueue()

    if not override_root:
        return

    main_dg = self.new_delta_generator()
    sidebar_dg = self.new_delta_generator(container=BlockPath.SIDEBAR)
    test_ctx = ReportContext(
        main_dg=main_dg,
        sidebar_dg=sidebar_dg,
        widgets=Widgets(),
        widget_ids_this_run=_WidgetIDSet(),
        uploaded_file_mgr=UploadedFileManager(),
    )
    setattr(threading.current_thread(), REPORT_CONTEXT_ATTR_NAME, test_ctx)
def setUp(self):
    """Give each test a fresh UploadedFileManager and an empty event log.

    The manager's on_files_added signal is connected to
    ``self._on_files_added``, which is expected to record events into
    ``self.filemgr_events``.
    """
    manager = UploadedFileManager()
    self.mgr = manager
    self.filemgr_events = []
    manager.on_files_added.connect(self._on_files_added)
class UploadedFileManagerTest(unittest.TestCase):
    """Tests for UploadedFileManager storage and its on_files_added signal."""

    def setUp(self):
        manager = UploadedFileManager()
        self.mgr = manager
        self.filemgr_events = []
        manager.on_files_added.connect(self._on_files_added)

    def _on_files_added(self, file_list, **kwargs):
        # Signal receiver: remember every emitted UploadedFileList, in order.
        self.filemgr_events.append(file_list)

    def test_add_file(self):
        mgr = self.mgr
        self.assertIsNone(mgr.get_files("non-report", "non-widget"))

        first_event = UploadedFileList("session", "widget", [file1])
        second_event = UploadedFileList("session", "widget", [file2])

        mgr.add_files("session", "widget", [file1])
        self.assertEqual([file1], mgr.get_files("session", "widget"))
        self.assertEqual([first_event], self.filemgr_events)

        # Adding again under the same (session, widget) key replaces the
        # stored files and emits a second event.
        mgr.add_files("session", "widget", [file2])
        self.assertEqual([file2], mgr.get_files("session", "widget"))
        self.assertEqual([first_event, second_event], self.filemgr_events)

    def test_remove_file(self):
        mgr = self.mgr

        # Removing an unknown key must not raise.
        mgr.remove_files("non-report", "non-widget")

        mgr.add_files("session", "widget", [file1])
        self.assertEqual([file1], mgr.get_files("session", "widget"))

        mgr.remove_files("session", "widget")
        self.assertIsNone(mgr.get_files("session", "widget"))

    def test_remove_all_files(self):
        mgr = self.mgr

        # Removing an unknown session must not raise.
        mgr.remove_session_files("non-report")

        # Same widget ID, two different sessions.
        mgr.add_files("session1", "widget", [file1])
        mgr.add_files("session2", "widget", [file1])

        expected_events = [
            UploadedFileList("session1", "widget", [file1]),
            UploadedFileList("session2", "widget", [file1]),
        ]

        # Only session1's files go away; session2 is untouched.
        mgr.remove_session_files("session1")
        self.assertIsNone(mgr.get_files("session1", "widget"))
        self.assertEqual([file1], mgr.get_files("session2", "widget"))

        self.assertEqual(expected_events, self.filemgr_events)
class Server(object):
    """The Streamlit server; at most one instance exists per process.

    Owns the ReportSessions (keyed by id), the ForwardMsg cache and the
    shared UploadedFileManager, and runs a coroutine on the Tornado IOLoop
    that forwards each session's queued messages to its browser websocket.
    """

    _singleton = None  # type: Optional[Server]

    @classmethod
    def get_current(cls):
        """Return the process-wide Server.

        Returns
        -------
        Server
            The singleton Server object.

        Raises
        ------
        RuntimeError
            If no Server has been constructed yet.
        """
        if cls._singleton is None:
            raise RuntimeError("Server has not been initialized yet")

        return cls._singleton

    def __init__(self, ioloop, script_path, command_line):
        """Create the server. It won't be started yet.

        Parameters
        ----------
        ioloop : tornado.ioloop.IOLoop
            The IOLoop on which start() spawns the server's run loop.
        script_path : str
            Path of the user's script, passed to every new ReportSession.
        command_line : str
            The user's command line, passed to every new ReportSession.

        Raises
        ------
        RuntimeError
            If a Server has already been constructed in this process.
        """
        if Server._singleton is not None:
            raise RuntimeError(
                "Server already initialized. Use .get_current() instead")

        Server._singleton = self

        _set_tornado_log_levels()

        self._ioloop = ioloop
        self._script_path = script_path
        self._command_line = command_line

        # Mapping of ReportSession.id -> SessionInfo.
        self._session_info_by_id = {}

        self._must_stop = threading.Event()
        # _set_state() logs the previous state, so it must exist first.
        self._state = None
        self._set_state(State.INITIAL)
        self._message_cache = ForwardMsgCache()
        self._uploaded_file_mgr = UploadedFileManager()
        self._uploaded_file_mgr.on_files_added.connect(self._on_file_uploaded)
        self._report = None  # type: Optional[Report]

    def _on_file_uploaded(self, file, **kwargs):
        """Event handler for UploadedFileManager.on_files_added.

        When files are uploaded by a user, schedule a re-run of the owning
        ReportSession. Files whose session no longer exists are removed so
        they don't stick around forever.

        Parameters
        ----------
        file : UploadedFileList
            The object sent by the signal (presumably the UploadedFileList
            that was just added). Only its ``session_id`` and ``widget_id``
            attributes are read.
        **kwargs
            Extra keyword payload from the signal, if any; ignored. Accepted
            so that keyword arguments sent with the signal don't make this
            receiver raise TypeError.
        """
        session_info = self._get_session_info(file.session_id)
        if session_info is not None:
            session_info.session.request_rerun()
        else:
            self._uploaded_file_mgr.remove_files(file.session_id,
                                                 file.widget_id)

    def _get_session_info(self, session_id):
        """Return the SessionInfo with the given id, or None if no such
        session exists.

        Parameters
        ----------
        session_id : str

        Returns
        -------
        SessionInfo or None
        """
        return self._session_info_by_id.get(session_id)

    def start(self, on_started):
        """Start the server.

        Parameters
        ----------
        on_started : callable
            A callback that will be called when the server's run-loop
            has started, and the server is ready to begin receiving clients.

        Raises
        ------
        RuntimeError
            If the server has already been started.
        """
        if self._state != State.INITIAL:
            raise RuntimeError("Server has already been started")

        LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        port = config.get_option("server.port")

        LOGGER.debug("Server started on port %s", port)

        self._ioloop.spawn_callback(self._loop_coroutine, on_started)

    def get_debug(self) -> Dict[str, Dict[str, Any]]:
        """Return debug info for the current report, or {} if there is none."""
        if self._report:
            return {"report": self._report.get_debug()}
        return {}

    def _create_app(self):
        """Create our tornado web app.

        Returns
        -------
        tornado.web.Application

        """
        base = config.get_option("server.baseUrlPath")

        routes = [
            (
                make_url_path_regex(base, "stream"),
                _BrowserWebSocketHandler,
                dict(server=self),
            ),
            (
                make_url_path_regex(base, "healthz"),
                HealthHandler,
                dict(callback=lambda: self.is_ready_for_browser_connection),
            ),
            (make_url_path_regex(base, "debugz"), DebugHandler,
             dict(server=self)),
            (make_url_path_regex(base, "metrics"), MetricsHandler),
            (
                make_url_path_regex(base, "message"),
                MessageCacheHandler,
                dict(cache=self._message_cache),
            ),
            (
                make_url_path_regex(base, "upload_file"),
                UploadFileRequestHandler,
                dict(file_mgr=self._uploaded_file_mgr),
            ),
            (make_url_path_regex(base, "media/(.*)"), MediaFileHandler),
        ]

        # In dev mode with Node, static content is served by the Node dev
        # server, so no static routes are registered here.
        if config.get_option("global.developmentMode") and config.get_option(
                "global.useNode"):
            LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = file_util.get_static_dir()
            LOGGER.debug("Serving static content from %s", static_path)

            routes.extend([
                (
                    make_url_path_regex(base, "(.*)"),
                    StaticFileHandler,
                    {
                        "path": "%s/" % static_path,
                        "default_filename": "index.html"
                    },
                ),
                (make_url_path_regex(base, trailing_slash=False),
                 AddSlashHandler),
            ])

        return tornado.web.Application(routes, **TORNADO_SETTINGS)

    def _set_state(self, new_state):
        """Log the transition and replace the current server state."""
        LOGGER.debug("Server state: %s -> %s", self._state, new_state)
        self._state = new_state

    @property
    def is_ready_for_browser_connection(self):
        return self._state not in (State.INITIAL, State.STOPPING,
                                   State.STOPPED)

    @property
    def browser_is_connected(self):
        return self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED

    @tornado.gen.coroutine
    def _loop_coroutine(self, on_started=None):
        """Main run loop: forward queued messages until stop() is called.

        Runs until the _must_stop event is set or an unexpected state is
        seen. Then it shuts down every session and, in all cases (including
        on error), calls _on_stopped().
        """
        try:
            if self._state == State.INITIAL:
                self._set_state(State.WAITING_FOR_FIRST_BROWSER)
            elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
                pass
            else:
                raise RuntimeError("Bad server state at start: %s" %
                                   self._state)

            if on_started is not None:
                on_started(self)

            while not self._must_stop.is_set():

                if self._state == State.WAITING_FOR_FIRST_BROWSER:
                    pass

                elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:

                    # Shallow-clone our sessions into a list, so we can iterate
                    # over it and not worry about whether it's being changed
                    # outside this coroutine.
                    session_infos = list(self._session_info_by_id.values())

                    for session_info in session_infos:
                        if session_info.ws is None:
                            # Preheated.
                            continue
                        msg_list = session_info.session.flush_browser_queue()
                        for msg in msg_list:
                            try:
                                self._send_message(session_info, msg)
                            except tornado.websocket.WebSocketClosedError:
                                self._close_report_session(
                                    session_info.session.id)
                                # The socket is gone; the rest of this
                                # session's messages cannot be delivered.
                                break
                            yield
                        yield

                elif self._state == State.NO_BROWSERS_CONNECTED:
                    pass

                else:
                    # Break out of the thread loop if we encounter any other state.
                    break

                yield tornado.gen.sleep(0.01)

            # Shut down all ReportSessions
            for session_info in list(self._session_info_by_id.values()):
                session_info.session.shutdown()

            self._set_state(State.STOPPED)

        except Exception as e:
            print("EXCEPTION!", e)
            # print_exc (not print_stack) shows where the caught exception
            # was raised; print_stack would only show this handler's frames.
            traceback.print_exc(file=sys.stdout)
            LOGGER.info(""" Please report this bug at https://github.com/streamlit/streamlit/issues. """)

        finally:
            self._on_stopped()

    def _send_message(self, session_info, msg):
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with websocket
        msg : ForwardMsg
            The message to send to the client

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                    msg, session_info.session, session_info.report_run_count):

                # This session has probably cached this message. Send
                # a reference instead.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(msg, session_info.session,
                                            session_info.report_run_count)

        # If this was a `report_finished` message, we increment the
        # report_run_count for this session, and update the cache
        if (msg.WhichOneof("type") == "report_finished" and
                msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY):
            LOGGER.debug(
                "Report finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.report_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.report_run_count)

        # Ship it off!
        session_info.ws.write_message(serialize_forward_msg(msg_to_send),
                                      binary=True)

    def stop(self):
        """Ask the run loop to exit.

        The loop then shuts down all sessions and calls _on_stopped().
        """
        click.secho(" Stopping...", fg="blue")
        self._set_state(State.STOPPING)
        self._must_stop.set()

    def _on_stopped(self):
        """Called when our runloop is exiting, to shut down the ioloop.
        This will end our process.

        (Tests can patch this method out, to prevent the test's ioloop
        from being shutdown.)
        """
        self._ioloop.stop()

    def add_preheated_report_session(self):
        """Register a fake browser with the server and run the script.

        This is used to start running the user's script even before the first
        browser connects.
        """
        session = self._create_report_session(ws=None)
        session.handle_rerun_script_request(is_preheat=True)

    def _create_report_session(self, ws):
        """Register a connected browser with the server.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler or None
            The newly-connected websocket handler or None if preheated
            connection.

        Returns
        -------
        ReportSession
            The newly-created ReportSession for this browser connection.

        """
        if PREHEATED_ID in self._session_info_by_id:
            # The preheated session is only reused while it is the sole
            # registered session.
            assert len(self._session_info_by_id) == 1
            LOGGER.debug("Reusing preheated context for ws %s", ws)
            session = self._session_info_by_id.pop(PREHEATED_ID).session
            session.id = 0
        else:
            LOGGER.debug("Creating new context for ws %s", ws)
            session = ReportSession(
                is_preheat=(ws is None),
                ioloop=self._ioloop,
                script_path=self._script_path,
                command_line=self._command_line,
                uploaded_file_manager=self._uploaded_file_mgr,
            )

        assert session.id not in self._session_info_by_id, (
            "session.id '%s' registered multiple times!" % session.id)

        self._session_info_by_id[session.id] = SessionInfo(ws, session)

        if ws is not None:
            self._set_state(State.ONE_OR_MORE_BROWSERS_CONNECTED)

        return session

    def _close_report_session(self, session_id):
        """Shutdown and remove a ReportSession.

        This function may be called multiple times for the same session,
        which is not an error. (Subsequent calls just no-op.)

        Parameters
        ----------
        session_id : str
            The ReportSession's id string.
        """
        session_info = self._session_info_by_id.pop(session_id, None)
        if session_info is None:
            return

        session_info.session.shutdown()

        if not self._session_info_by_id:
            self._set_state(State.NO_BROWSERS_CONNECTED)
def get_app(self):
    """Build a Tornado app that serves only the /upload_file endpoint.

    The UploadedFileManager behind the handler is kept on ``self.file_mgr``
    so tests can inspect what was stored.
    """
    self.file_mgr = UploadedFileManager()
    upload_route = (
        "/upload_file",
        UploadFileRequestHandler,
        {"file_mgr": self.file_mgr},
    )
    return tornado.web.Application([upload_route])
class UploadFileRequestHandlerTest(tornado.testing.AsyncHTTPTestCase):
    """Tests the /upload_file endpoint."""

    def get_app(self):
        self.file_mgr = UploadedFileManager()
        upload_route = (
            "/upload_file",
            UploadFileRequestHandler,
            {"file_mgr": self.file_mgr},
        )
        return tornado.web.Application([upload_route])

    def _upload_files(self, params):
        # multipart/form-data bodies are very fiddly to compose by hand and
        # Tornado has no builder for them, so requests prepares the request
        # and the test server is then hit through self.fetch().
        prepared = requests.Request(
            method="POST",
            url=self.get_url("/upload_file"),
            files=params,
        ).prepare()
        return self.fetch(
            "/upload_file",
            method=prepared.method,
            headers=prepared.headers,
            body=prepared.body,
        )

    def test_upload_one_file(self):
        """Uploading a file should populate our file_mgr."""
        uploaded = UploadedFile("image.png", b"123")
        params = {
            uploaded.name: uploaded,
            "sessionId": (None, "fooReport"),
            "widgetId": (None, "barWidget"),
        }

        response = self._upload_files(params)

        self.assertEqual(200, response.code)
        self.assertEqual([uploaded],
                         self.file_mgr.get_files("fooReport", "barWidget"))

    def test_upload_multiple_files(self):
        payloads = (b"123", b"456", b"789")
        uploads = [
            UploadedFile("image%d.png" % index, payload)
            for index, payload in enumerate(payloads, start=1)
        ]

        # Files first, then the form fields, as in the single-file case.
        params = {upload.name: upload for upload in uploads}
        params["sessionId"] = (None, "fooReport")
        params["widgetId"] = (None, "barWidget")

        response = self._upload_files(params)

        self.assertEqual(200, response.code)
        stored = self.file_mgr.get_files("fooReport", "barWidget")
        self.assertEqual(
            sorted(uploads, key=_get_filename),
            sorted(stored, key=_get_filename),
        )

    def test_missing_params(self):
        """Missing params in the body should fail with 400 status."""
        # "widgetId" is deliberately left out.
        params = {
            "image.png": ("image.png", b"1234"),
            "sessionId": (None, "fooReport"),
        }

        response = self._upload_files(params)

        self.assertEqual(400, response.code)
        self.assertIn("Missing 'widgetId'", response.reason)

    def test_missing_file(self):
        """Missing file should fail with 400 status."""
        # Only form fields; no file part at all.
        params = {
            "sessionId": (None, "fooReport"),
            "widgetId": (None, "barWidget"),
        }

        response = self._upload_files(params)

        self.assertEqual(400, response.code)
        self.assertIn("Expected at least 1 file, but got 0", response.reason)
def test_msg_hash(self):
    """Test chunked upload, progress and deletion in UploadedFileManager.

    Despite its name, this test does not exercise ForwardMsg hashing. It
    stores the same payload under two widget IDs: A as a single chunk and B
    as two half-size chunks. It then checks the progress values and the
    reassembled data, deletes A, and finally deletes all files.
    """
    widget_idA = "A0123456789"
    widget_idB = "B0123456789"
    file_name = "example_file.png"
    file_bytes = bytearray(
        "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789",
        "utf-8",
    )
    # Widget B's payload is split into two chunks at this offset.
    half = len(file_bytes) // 2
    today = date.today()

    uploaded_file_mgr = UploadedFileManager()

    uploaded_file_mgr.create_or_clear_file(widget_idA, file_name,
                                           len(file_bytes), today, 1)
    uploaded_file_mgr.create_or_clear_file(widget_idB, file_name,
                                           len(file_bytes), today, 2)

    # process_chunk reports progress as a fraction (1 == complete).
    progress_a = uploaded_file_mgr.process_chunk(widget_idA, 0, file_bytes)
    self.assertEqual(progress_a, 1)

    progress_b = uploaded_file_mgr.process_chunk(widget_idB, 0,
                                                 file_bytes[:half])
    self.assertEqual(progress_b, 0.5)

    progress_b = uploaded_file_mgr.process_chunk(widget_idB, 1,
                                                 file_bytes[half:])
    self.assertEqual(progress_b, 1)

    # get_data reports progress as a percentage plus the reassembled bytes.
    progress_a, data_a = uploaded_file_mgr.get_data(widget_idA)
    progress_b, data_b = uploaded_file_mgr.get_data(widget_idB)
    self.assertEqual(progress_a, 100)
    self.assertEqual(progress_b, 100)
    self.assertEqual(len(data_a), len(file_bytes))
    self.assertEqual(data_a, file_bytes)
    self.assertEqual(data_a, data_b)

    # A deleted file reports zero progress and no data.
    uploaded_file_mgr.delete_file(widget_idA)

    progress_a, data_a = uploaded_file_mgr.get_data(widget_idA)
    self.assertEqual(progress_a, 0)
    self.assertEqual(data_a, None)

    uploaded_file_mgr.delete_all_files()

    progress_b, data_b = uploaded_file_mgr.get_data(widget_idB)
    self.assertEqual(progress_b, 0)
    self.assertEqual(data_b, None)
class ReportSession(object): """ Contains session data for a single "user" of an active report (that is, a connected browser tab). Each ReportSession has its own Report, root DeltaGenerator, ScriptRunner, and widget state. A ReportSession is attached to each thread involved in running its Report. """ _next_id = 0 def __init__(self, ioloop, script_path, command_line): """Initialize the ReportSession. Parameters ---------- ioloop : tornado.ioloop.IOLoop The Tornado IOLoop that we're running within. script_path : str Path of the Python file from which this report is generated. command_line : str Command line as input by the user. """ # Each ReportSession gets a unique ID self.id = ReportSession._next_id ReportSession._next_id += 1 self._ioloop = ioloop self._report = Report(script_path, command_line) self._state = ReportSessionState.REPORT_NOT_RUNNING self._uploaded_file_mgr = UploadedFileManager() self._widget_states = WidgetStates() self._local_sources_watcher = LocalSourcesWatcher( self._report, self._on_source_file_changed) self._sent_initialize_message = False self._storage = None self._maybe_reuse_previous_run = False self._run_on_save = config.get_option("server.runOnSave") # The ScriptRequestQueue is the means by which we communicate # with the active ScriptRunner. self._script_request_queue = ScriptRequestQueue() self._scriptrunner = None LOGGER.debug("ReportSession initialized (id=%s)", self.id) def flush_browser_queue(self): """Clears the report queue and returns the messages it contained. The Server calls this periodically to deliver new messages to the browser connected to this report. Returns ------- list[ForwardMsg] The messages that were removed from the queue and should be delivered to the browser. """ return self._report.flush_browser_queue() def shutdown(self): """Shuts down the ReportSession. It's an error to use a ReportSession after it's been shut down. 
""" if self._state != ReportSessionState.SHUTDOWN_REQUESTED: LOGGER.debug("Shutting down (id=%s)", self.id) self._uploaded_file_mgr.delete_all_files() # Shut down the ScriptRunner, if one is active. # self._state must not be set to SHUTDOWN_REQUESTED until # after this is called. if self._scriptrunner is not None: self._enqueue_script_request(ScriptRequest.SHUTDOWN) self._state = ReportSessionState.SHUTDOWN_REQUESTED self._local_sources_watcher.close() def enqueue(self, msg): """Enqueues a new ForwardMsg to our browser queue. This can be called on both the main thread and a ScriptRunner run thread. Parameters ---------- msg : ForwardMsg The message to enqueue """ if not config.get_option("client.displayEnabled"): return # Avoid having two maybe_handle_execution_control_request running on # top of each other when tracer is installed. This leads to a lock # contention. if not config.get_option("runner.installTracer"): # If we have an active ScriptRunner, signal that it can handle an # execution control request. (Copy the scriptrunner reference to # avoid it being unset from underneath us, as this function can be # called outside the main thread.) scriptrunner = self._scriptrunner if scriptrunner is not None: scriptrunner.maybe_handle_execution_control_request() self._report.enqueue(msg) def enqueue_exception(self, e): """Enqueues an Exception message. Parameters ---------- e : BaseException """ # This does a few things: # 1) Clears the current report in the browser. # 2) Marks the current report as "stopped" in the browser. # 3) HACK: Resets any script params that may have been broken (e.g. 
the # command-line when rerunning with wrong argv[0]) self._on_scriptrunner_event( ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS) self._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STARTED) self._on_scriptrunner_event( ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS) msg = ForwardMsg() msg.metadata.delta_id = 0 exception_proto.marshall(msg.delta.new_element.exception, e) self.enqueue(msg) def request_rerun(self, widget_state=None): """Signal that we're interested in running the script. If the script is not already running, it will be started immediately. Otherwise, a rerun will be requested. Parameters ---------- widget_state : dict | None The widget state dictionary to run the script with, or None to use the widget state from the previous run of the script. """ self._enqueue_script_request(ScriptRequest.RERUN, RerunData(widget_state)) def _on_source_file_changed(self): """One of our source files changed. Schedule a rerun if appropriate.""" if self._run_on_save: self.request_rerun() else: self._enqueue_file_change_message() def _clear_queue(self): self._report.clear() def _on_scriptrunner_event(self, event, exception=None, widget_states=None): """Called when our ScriptRunner emits an event. This is *not* called on the main thread. Parameters ---------- event : ScriptRunnerEvent exception : BaseException | None An exception thrown during compilation. Set only for the SCRIPT_STOPPED_WITH_COMPILE_ERROR event. widget_states : streamlit.proto.Widget_pb2.WidgetStates | None The ScriptRunner's final WidgetStates. Set only for the SHUTDOWN event. """ LOGGER.debug("OnScriptRunnerEvent: %s", event) prev_state = self._state if event == ScriptRunnerEvent.SCRIPT_STARTED: if self._state != ReportSessionState.SHUTDOWN_REQUESTED: self._state = ReportSessionState.REPORT_IS_RUNNING if config.get_option("server.liveSave"): # Enqueue into the IOLoop so it runs without blocking AND runs # on the main thread. 
self._ioloop.spawn_callback(self._save_running_report) self._clear_queue() self._maybe_enqueue_initialize_message() self._enqueue_new_report_message() elif (event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR): if self._state != ReportSessionState.SHUTDOWN_REQUESTED: self._state = ReportSessionState.REPORT_NOT_RUNNING script_succeeded = event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS self._enqueue_report_finished_message( ForwardMsg.FINISHED_SUCCESSFULLY if script_succeeded else ForwardMsg.FINISHED_WITH_COMPILE_ERROR) if config.get_option("server.liveSave"): # Enqueue into the IOLoop so it runs without blocking AND runs # on the main thread. self._ioloop.spawn_callback(self._save_final_report_and_quit) if script_succeeded: # When a script completes successfully, we update our # LocalSourcesWatcher to account for any source code changes # that change which modules should be watched. (This is run on # the main thread, because LocalSourcesWatcher is not # thread safe.) self._ioloop.spawn_callback( self._local_sources_watcher.update_watched_modules) else: # When a script fails to compile, we send along the exception. from streamlit.elements import exception_proto msg = ForwardMsg() exception_proto.marshall( msg.session_event.script_compilation_exception, exception) self.enqueue(msg) elif event == ScriptRunnerEvent.SHUTDOWN: # When ScriptRunner shuts down, update our local reference to it, # and check to see if we need to spawn a new one. (This is run on # the main thread.) def on_shutdown(): self._widget_states = widget_states self._scriptrunner = None # Because a new ScriptEvent could have been enqueued while the # scriptrunner was shutting down, we check to see if we should # create a new one. (Otherwise, a newly-enqueued ScriptEvent # won't be processed until another event is enqueued.) 
self._maybe_create_scriptrunner() self._ioloop.spawn_callback(on_shutdown) # Send a message if our run state changed report_was_running = prev_state == ReportSessionState.REPORT_IS_RUNNING report_is_running = self._state == ReportSessionState.REPORT_IS_RUNNING if report_is_running != report_was_running: self._enqueue_session_state_changed_message() def _enqueue_session_state_changed_message(self): msg = ForwardMsg() msg.session_state_changed.run_on_save = self._run_on_save msg.session_state_changed.report_is_running = ( self._state == ReportSessionState.REPORT_IS_RUNNING) self.enqueue(msg) def _enqueue_file_change_message(self): LOGGER.debug("Enqueuing report_changed message (id=%s)", self.id) msg = ForwardMsg() msg.session_event.report_changed_on_disk = True self.enqueue(msg) def _maybe_enqueue_initialize_message(self): if self._sent_initialize_message: return self._sent_initialize_message = True msg = ForwardMsg() imsg = msg.initialize imsg.config.sharing_enabled = config.get_option( "global.sharingMode") != "off" imsg.config.gather_usage_stats = config.get_option( "browser.gatherUsageStats") imsg.config.max_cached_message_age = config.get_option( "global.maxCachedMessageAge") imsg.config.mapbox_token = config.get_option("mapbox.token") LOGGER.debug( "New browser connection: " "gather_usage_stats=%s, " "sharing_enabled=%s, " "max_cached_message_age=%s", imsg.config.gather_usage_stats, imsg.config.sharing_enabled, imsg.config.max_cached_message_age, ) imsg.environment_info.streamlit_version = __version__ imsg.environment_info.python_version = ".".join( map(str, sys.version_info)) imsg.session_state.run_on_save = self._run_on_save imsg.session_state.report_is_running = ( self._state == ReportSessionState.REPORT_IS_RUNNING) imsg.user_info.installation_id = __installation_id__ if Credentials.get_current().activation: imsg.user_info.email = Credentials.get_current().activation.email else: imsg.user_info.email = "" imsg.command_line = self._report.command_line 
self.enqueue(msg) def _enqueue_new_report_message(self): self._report.generate_new_id() msg = ForwardMsg() msg.new_report.id = self._report.report_id msg.new_report.name = self._report.name msg.new_report.script_path = self._report.script_path self.enqueue(msg) def _enqueue_report_finished_message(self, status): """Enqueues a report_finished ForwardMsg. Parameters ---------- status : ReportFinishedStatus """ msg = ForwardMsg() msg.report_finished = status self.enqueue(msg) def handle_rerun_script_request(self, command_line=None, widget_state=None, is_preheat=False): """Tells the ScriptRunner to re-run its report. Parameters ---------- command_line : str | None The new command line arguments to run the script with, or None to use its previous command line value. widget_state : WidgetStates | None The WidgetStates protobuf to run the script with, or None to use its previous widget states. is_preheat: boolean True if this ReportSession should run the script immediately, and then ignore the next rerun request if it matches the already-ran widget state. """ if is_preheat: self._maybe_reuse_previous_run = True # For next time. elif self._maybe_reuse_previous_run: # If this is a "preheated" ReportSession, reuse the previous run if # the widget state matches. But only do this one time ever. 
self._maybe_reuse_previous_run = False has_widget_state = (widget_state is not None and len(widget_state.widgets) > 0) if not has_widget_state: LOGGER.debug( "Skipping rerun since the preheated run is the same") return self.request_rerun(widget_state) def handle_upload_file(self, upload_file): self._uploaded_file_mgr.create_or_clear_file( widget_id=upload_file.widget_id, name=upload_file.name, size=upload_file.size, last_modified=upload_file.lastModified, chunks=upload_file.chunks, ) self.handle_rerun_script_request(widget_state=self._widget_states) def handle_upload_file_chunk(self, upload_file_chunk): progress = self._uploaded_file_mgr.process_chunk( widget_id=upload_file_chunk.widget_id, index=upload_file_chunk.index, data=upload_file_chunk.data, ) if progress == 1: self.handle_rerun_script_request(widget_state=self._widget_states) def handle_delete_uploaded_file(self, delete_uploaded_file): self._uploaded_file_mgr.delete_file( widget_id=delete_uploaded_file.widget_id) self.handle_rerun_script_request(widget_state=self._widget_states) def handle_stop_script_request(self): """Tells the ScriptRunner to stop running its report.""" self._enqueue_script_request(ScriptRequest.STOP) def handle_clear_cache_request(self): """Clears this report's cache. Because this cache is global, it will be cleared for all users. """ # Setting verbose=True causes clear_cache to print to stdout. # Since this command was initiated from the browser, the user # doesn't need to see the results of the command in their # terminal. caching.clear_cache() def handle_set_run_on_save_request(self, new_value): """Changes our run_on_save flag to the given value. The browser will be notified of the change. Parameters ---------- new_value : bool New run_on_save value """ self._run_on_save = new_value self._enqueue_session_state_changed_message() def _enqueue_script_request(self, request, data=None): """Enqueue a ScriptEvent into our ScriptEventQueue. 
If a script thread is not already running, one will be created to handle the event. Parameters ---------- request : ScriptRequest The type of request. data : Any Data associated with the request, if any. """ if self._state == ReportSessionState.SHUTDOWN_REQUESTED: LOGGER.warning("Discarding %s request after shutdown" % request) return self._script_request_queue.enqueue(request, data) self._maybe_create_scriptrunner() def _maybe_create_scriptrunner(self): """Create a new ScriptRunner if we have unprocessed script requests. This is called every time a ScriptRequest is enqueued, and also after a ScriptRunner shuts down, in case new requests were enqueued during its termination. This function should only be called on the main thread. """ if (self._state == ReportSessionState.SHUTDOWN_REQUESTED or self._scriptrunner is not None or not self._script_request_queue.has_request): return # Create the ScriptRunner, attach event handlers, and start it self._scriptrunner = ScriptRunner( report=self._report, enqueue_forward_msg=self.enqueue, widget_states=self._widget_states, request_queue=self._script_request_queue, uploaded_file_mgr=self._uploaded_file_mgr, ) self._scriptrunner.on_event.connect(self._on_scriptrunner_event) self._scriptrunner.start() @tornado.gen.coroutine def handle_save_request(self, ws): """Save serialized version of report deltas to the cloud. "Progress" ForwardMsgs will be sent to the client during the upload. These messages are sent "out of band" - that is, they don't get enqueued into the ReportQueue (because they're not part of the report). Instead, they're written directly to the report's WebSocket. Parameters ---------- ws : _BrowserWebSocketHandler The report's websocket handler. """ @tornado.gen.coroutine def progress(percent): progress_msg = ForwardMsg() progress_msg.upload_report_progress = percent yield ws.write_message(serialize_forward_msg(progress_msg), binary=True) # Indicate that the save is starting. 
try: yield progress(0) url = yield self._save_final_report(progress) # Indicate that the save is done. progress_msg = ForwardMsg() progress_msg.report_uploaded = url yield ws.write_message(serialize_forward_msg(progress_msg), binary=True) except Exception as e: # Horrible hack to show something if something breaks. err_msg = "%s: %s" % (type(e).__name__, str(e) or "No further details.") progress_msg = ForwardMsg() progress_msg.report_uploaded = err_msg yield ws.write_message(serialize_forward_msg(progress_msg), binary=True) LOGGER.warning("Failed to save report:", exc_info=e) @tornado.gen.coroutine def _save_running_report(self): files = self._report.serialize_running_report_to_files() url = yield self._get_storage().save_report_files( self._report.report_id, files) if config.get_option("server.liveSave"): url_util.print_url("Saved running app", url) raise tornado.gen.Return(url) @tornado.gen.coroutine def _save_final_report(self, progress_coroutine=None): files = self._report.serialize_final_report_to_files() url = yield self._get_storage().save_report_files( self._report.report_id, files, progress_coroutine) if config.get_option("server.liveSave"): url_util.print_url("Saved final app", url) raise tornado.gen.Return(url) @tornado.gen.coroutine def _save_final_report_and_quit(self): yield self._save_final_report() self._ioloop.stop() def _get_storage(self): if self._storage is None: sharing_mode = config.get_option("global.sharingMode") if sharing_mode == "s3": self._storage = S3Storage() elif sharing_mode == "file": self._storage = FileStorage() else: raise RuntimeError("Unsupported sharing mode '%s'" % sharing_mode) return self._storage