class UploadedFileManagerTest(unittest.TestCase):
    """Exercise UploadedFileManager's add/get/remove API and the
    UploadedFileList events it emits through ``on_files_added``."""

    def setUp(self):
        self.mgr = UploadedFileManager()
        # Every payload delivered by on_files_added, in arrival order.
        self.filemgr_events = []
        self.mgr.on_files_added.connect(self._on_files_added)

    def _on_files_added(self, file_list, **kwargs):
        """Signal receiver that records each emitted file list."""
        self.filemgr_events.append(file_list)

    def test_add_file(self):
        mgr = self.mgr
        self.assertIsNone(mgr.get_files("non-report", "non-widget"))

        first_event = UploadedFileList("session", "widget", [file1])
        second_event = UploadedFileList("session", "widget", [file2])

        mgr.add_files("session", "widget", [file1])
        self.assertEqual([file1], mgr.get_files("session", "widget"))
        self.assertEqual([first_event], self.filemgr_events)

        # A second add under the same (session, widget) key: get_files now
        # returns only the new list, and a second event is recorded.
        mgr.add_files("session", "widget", [file2])
        self.assertEqual([file2], mgr.get_files("session", "widget"))
        self.assertEqual([first_event, second_event], self.filemgr_events)

    def test_remove_file(self):
        mgr = self.mgr
        # Removing an unknown (session, widget) pair must be a silent no-op.
        mgr.remove_files("non-report", "non-widget")

        mgr.add_files("session", "widget", [file1])
        self.assertEqual([file1], mgr.get_files("session", "widget"))

        mgr.remove_files("session", "widget")
        self.assertIsNone(mgr.get_files("session", "widget"))

    def test_remove_all_files(self):
        mgr = self.mgr
        # Removing an unknown session must be a silent no-op.
        mgr.remove_session_files("non-report")

        # Same widget ID under two different sessions; only session1's
        # entry should disappear.
        mgr.add_files("session1", "widget", [file1])
        mgr.add_files("session2", "widget", [file1])

        first_event = UploadedFileList("session1", "widget", [file1])
        second_event = UploadedFileList("session2", "widget", [file1])

        mgr.remove_session_files("session1")
        self.assertIsNone(mgr.get_files("session1", "widget"))
        self.assertEqual([file1], mgr.get_files("session2", "widget"))
        self.assertEqual([first_event, second_event], self.filemgr_events)
class Server(object):
    """Process-wide singleton that builds the tornado app and tracks the
    ReportSession created for each browser connection (plus an optional
    preheated session that has no websocket).
    """

    _singleton = None  # type: Optional[Server]

    @classmethod
    def get_current(cls):
        """Return the singleton Server.

        Returns
        -------
        Server
            The singleton Server object.

        Raises
        ------
        RuntimeError
            If no Server has been constructed yet.
        """
        if cls._singleton is None:
            raise RuntimeError("Server has not been initialized yet")

        return Server._singleton

    def __init__(self, ioloop, script_path, command_line):
        """Create the server. It won't be started yet.

        Parameters
        ----------
        ioloop : tornado.ioloop.IOLoop
        script_path : str
        command_line : str

        Raises
        ------
        RuntimeError
            If a Server instance has already been created.
        """
        if Server._singleton is not None:
            raise RuntimeError(
                "Server already initialized. Use .get_current() instead")

        Server._singleton = self

        _set_tornado_log_levels()

        self._ioloop = ioloop
        self._script_path = script_path
        self._command_line = command_line

        # Mapping of ReportSession.id -> SessionInfo.
        self._session_info_by_id = {}

        self._must_stop = threading.Event()
        self._state = None
        self._set_state(State.INITIAL)
        self._message_cache = ForwardMsgCache()
        self._uploaded_file_mgr = UploadedFileManager()
        self._uploaded_file_mgr.on_files_added.connect(self._on_file_uploaded)
        self._report = None  # type: Optional[Report]

    def _on_file_uploaded(self, file, **kwargs):
        """Event handler for UploadedFileManager.on_files_added.

        When files are uploaded by a user, schedule a re-run of the
        corresponding ReportSession.

        Parameters
        ----------
        file : UploadedFileList
            The file list that was just added. Only its ``session_id`` and
            ``widget_id`` attributes are used here.
        **kwargs
            Any extra keyword arguments sent along with the signal; ignored.
        """
        session_info = self._get_session_info(file.session_id)
        if session_info is not None:
            session_info.session.request_rerun()
        else:
            # If an uploaded file doesn't belong to an existing session,
            # remove it so it doesn't stick around forever.
            self._uploaded_file_mgr.remove_files(file.session_id,
                                                 file.widget_id)

    def _get_session_info(self, session_id):
        """Return the SessionInfo with the given id, or None if no such
        session exists.

        Parameters
        ----------
        session_id : str

        Returns
        -------
        SessionInfo or None
        """
        return self._session_info_by_id.get(session_id)

    def start(self, on_started):
        """Start the server.

        Parameters
        ----------
        on_started : callable
            A callback that will be called when the server's run-loop
            has started, and the server is ready to begin receiving clients.

        Raises
        ------
        RuntimeError
            If the server is not in the INITIAL state.
        """
        if self._state != State.INITIAL:
            raise RuntimeError("Server has already been started")

        LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        port = config.get_option("server.port")

        LOGGER.debug("Server started on port %s", port)

        self._ioloop.spawn_callback(self._loop_coroutine, on_started)

    def get_debug(self) -> Dict[str, Dict[str, Any]]:
        """Return ``{"report": <report debug info>}``, or ``{}`` if no
        report is set."""
        if self._report:
            return {"report": self._report.get_debug()}
        return {}

    def _create_app(self):
        """Create our tornado web app.

        Returns
        -------
        tornado.web.Application
        """
        base = config.get_option("server.baseUrlPath")

        routes = [
            (
                make_url_path_regex(base, "stream"),
                _BrowserWebSocketHandler,
                dict(server=self),
            ),
            (
                make_url_path_regex(base, "healthz"),
                HealthHandler,
                dict(callback=lambda: self.is_ready_for_browser_connection),
            ),
            (make_url_path_regex(base, "debugz"), DebugHandler,
             dict(server=self)),
            (make_url_path_regex(base, "metrics"), MetricsHandler),
            (
                make_url_path_regex(base, "message"),
                MessageCacheHandler,
                dict(cache=self._message_cache),
            ),
            (
                make_url_path_regex(base, "upload_file"),
                UploadFileRequestHandler,
                dict(file_mgr=self._uploaded_file_mgr),
            ),
            (make_url_path_regex(base, "media/(.*)"), MediaFileHandler),
        ]

        if config.get_option("global.developmentMode") and config.get_option(
                "global.useNode"):
            # Static assets are expected to come from the Node dev server,
            # so no static routes are registered here.
            LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = file_util.get_static_dir()
            LOGGER.debug("Serving static content from %s", static_path)

            routes.extend([
                (
                    make_url_path_regex(base, "(.*)"),
                    StaticFileHandler,
                    {
                        "path": "%s/" % static_path,
                        "default_filename": "index.html"
                    },
                ),
                (make_url_path_regex(base, trailing_slash=False),
                 AddSlashHandler),
            ])

        return tornado.web.Application(routes, **TORNADO_SETTINGS)

    def _set_state(self, new_state):
        """Log the transition and record ``new_state`` as the current
        state."""
        # Lazy %-args: only formatted if DEBUG logging is enabled.
        LOGGER.debug("Server state: %s -> %s", self._state, new_state)
        self._state = new_state

    @property
    def is_ready_for_browser_connection(self):
        """True unless the server is INITIAL, STOPPING or STOPPED."""
        return self._state not in (State.INITIAL, State.STOPPING,
                                   State.STOPPED)

    @property
    def browser_is_connected(self):
        """True when the state is ONE_OR_MORE_BROWSERS_CONNECTED."""
        return self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED

    @tornado.gen.coroutine
    def _loop_coroutine(self, on_started=None):
        """Main run-loop.

        Until ``stop()`` sets the stop flag, roughly every 10 ms this drains
        each websocket-backed session's browser queue and sends the messages.
        On exit it shuts down all sessions, sets the STOPPED state, and
        always calls ``_on_stopped()``.

        Parameters
        ----------
        on_started : callable or None
            Called once with this Server after the initial state check.
        """
        try:
            if self._state == State.INITIAL:
                self._set_state(State.WAITING_FOR_FIRST_BROWSER)
            elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
                pass
            else:
                raise RuntimeError("Bad server state at start: %s" %
                                   self._state)

            if on_started is not None:
                on_started(self)

            while not self._must_stop.is_set():
                if self._state == State.WAITING_FOR_FIRST_BROWSER:
                    pass

                elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
                    # Shallow-clone our sessions into a list, so we can
                    # iterate over it and not worry about whether it's being
                    # changed outside this coroutine.
                    session_infos = list(self._session_info_by_id.values())

                    for session_info in session_infos:
                        if session_info.ws is None:
                            # Preheated.
                            continue
                        msg_list = session_info.session.flush_browser_queue()
                        for msg in msg_list:
                            try:
                                self._send_message(session_info, msg)
                            except tornado.websocket.WebSocketClosedError:
                                self._close_report_session(
                                    session_info.session.id)
                                # The session is shut down now; sending (and
                                # caching) the rest of its queue would only
                                # fail again on the closed socket.
                                break
                            yield
                        yield

                elif self._state == State.NO_BROWSERS_CONNECTED:
                    pass

                else:
                    # Break out of the thread loop if we encounter any other
                    # state.
                    break

                yield tornado.gen.sleep(0.01)

            # Shut down all ReportSessions
            for session_info in list(self._session_info_by_id.values()):
                session_info.session.shutdown()

            self._set_state(State.STOPPED)

        except Exception as e:
            print("EXCEPTION!", e)
            # print_exc (not print_stack) so the traceback of the exception
            # that was actually raised is shown.
            traceback.print_exc(file=sys.stdout)
            LOGGER.info(""" Please report this bug at https://github.com/streamlit/streamlit/issues. """)

        finally:
            self._on_stopped()

    def _send_message(self, session_info, msg):
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of
        the message.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with websocket
        msg : ForwardMsg
            The message to send to the client
        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                    msg, session_info.session, session_info.report_run_count):
                # This session has probably cached this message. Send
                # a reference instead.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(msg, session_info.session,
                                            session_info.report_run_count)

        # If this was a `report_finished` message, we increment the
        # report_run_count for this session, and update the cache
        if (msg.WhichOneof("type") == "report_finished" and
                msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY):
            LOGGER.debug(
                "Report finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.report_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.report_run_count)

        # Ship it off!
        session_info.ws.write_message(serialize_forward_msg(msg_to_send),
                                      binary=True)

    def stop(self):
        """Request shutdown: set STOPPING and signal the run-loop to exit."""
        click.secho(" Stopping...", fg="blue")
        self._set_state(State.STOPPING)
        self._must_stop.set()

    def _on_stopped(self):
        """Called when our runloop is exiting, to shut down the ioloop.
        This will end our process.

        (Tests can patch this method out, to prevent the test's ioloop
        from being shutdown.)
        """
        self._ioloop.stop()

    def add_preheated_report_session(self):
        """Register a fake browser with the server and run the script.

        This is used to start running the user's script even before the
        first browser connects.
        """
        session = self._create_report_session(ws=None)
        session.handle_rerun_script_request(is_preheat=True)

    def _create_report_session(self, ws):
        """Register a connected browser with the server.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler or None
            The newly-connected websocket handler or None if preheated
            connection.

        Returns
        -------
        ReportSession
            The newly-created ReportSession for this browser connection.
        """
        if PREHEATED_ID in self._session_info_by_id:
            assert len(self._session_info_by_id) == 1
            LOGGER.debug("Reusing preheated context for ws %s", ws)
            session = self._session_info_by_id[PREHEATED_ID].session
            del self._session_info_by_id[PREHEATED_ID]
            # NOTE(review): the reused session is re-keyed to id 0; confirm
            # this cannot collide with ids ReportSession generates itself.
            session.id = 0
        else:
            LOGGER.debug("Creating new context for ws %s", ws)
            session = ReportSession(
                is_preheat=(ws is None),
                ioloop=self._ioloop,
                script_path=self._script_path,
                command_line=self._command_line,
                uploaded_file_manager=self._uploaded_file_mgr,
            )

        assert session.id not in self._session_info_by_id, (
            "session.id '%s' registered multiple times!" % session.id)

        self._session_info_by_id[session.id] = SessionInfo(ws, session)

        # A preheated session (ws is None) doesn't count as a browser.
        if ws is not None:
            self._set_state(State.ONE_OR_MORE_BROWSERS_CONNECTED)

        return session

    def _close_report_session(self, session_id):
        """Shutdown and remove a ReportSession.

        This function may be called multiple times for the same session,
        which is not an error. (Subsequent calls just no-op.)

        Parameters
        ----------
        session_id : str
            The ReportSession's id string.
        """
        session_info = self._session_info_by_id.pop(session_id, None)
        if session_info is not None:
            session_info.session.shutdown()

        if not self._session_info_by_id:
            self._set_state(State.NO_BROWSERS_CONNECTED)