def __init__(self, ioloop: tornado.ioloop.IOLoop, script_path: str, command_line: str):
    """Create the Server singleton. It is not started here.

    Parameters
    ----------
    ioloop : tornado.ioloop.IOLoop
        Stored as self._ioloop.
    script_path : str
        Stored as self._script_path.
    command_line : str
        Stored as self._command_line.

    Raises
    ------
    RuntimeError
        If a Server instance has already been created.
    """
    # Only one Server may exist; later callers are pointed at get_current().
    if Server._singleton is not None:
        raise RuntimeError(
            "Server already initialized. Use .get_current() instead")
    Server._singleton = self

    _set_tornado_log_levels()

    self._ioloop = ioloop
    self._script_path = script_path
    self._command_line = command_line

    # ReportSession.id -> SessionInfo for every registered session.
    self._session_info_by_id = {}  # type: Dict[str, SessionInfo]

    self._must_stop = threading.Event()

    # Give _state a value before _set_state() runs, since _set_state() may
    # read the previous state (e.g. for logging).
    self._state = None
    self._set_state(State.INITIAL)

    self._message_cache = ForwardMsgCache()

    uploaded_file_mgr = UploadedFileManager()
    self._uploaded_file_mgr = uploaded_file_mgr
    uploaded_file_mgr.on_files_updated.connect(self.on_files_updated)

    self._report = None  # type: Optional[Report]
    self._preheated_session_id = None  # type: Optional[str]
def get_app(self):
    """Build an app serving only the upload route.

    The session lookup returns True for any session id, so the handler sees
    every session as existing.
    """
    self.file_mgr = UploadedFileManager()
    self._get_session_info = lambda x: True

    upload_handler_kwargs = dict(
        file_mgr=self.file_mgr,
        get_session_info=self._get_session_info,
    )
    return tornado.web.Application(
        [(UPLOAD_FILE_ROUTE, UploadFileRequestHandler, upload_handler_kwargs)]
    )
def test_handle_save_request(self, _1):
    """Test that handle_save_request serializes files correctly.

    Two deltas are emitted. The test then checks that storage received one
    .pb file per delta, followed by the manifest, and that the stored
    messages equal the ones that were queued.
    """
    # Create a ReportSession with some mocked bits
    rs = ReportSession(
        self.io_loop, "mock_report.py", "", UploadedFileManager(), None
    )
    rs._report.report_id = "TestReportID"

    orig_ctx = get_report_ctx()
    ctx = ReportContext(
        "TestSessionID",
        rs._report.enqueue,
        "",
        SessionState(),
        UploadedFileManager(),
    )
    add_report_ctx(ctx=ctx)

    # Restore the original ReportContext even if an assertion below fails,
    # so a failing run cannot leak this test's context into later tests.
    try:
        rs._scriptrunner = MagicMock()

        storage = MockStorage()
        rs._storage = storage

        # Send two deltas: empty and markdown
        st.empty()
        st.markdown("Text!")

        yield rs.handle_save_request(_create_mock_websocket())

        # Check the order of the received files. Manifest should be last.
        self.assertEqual(3, len(storage.files))
        self.assertEqual("reports/TestReportID/0.pb", storage.get_filename(0))
        self.assertEqual("reports/TestReportID/1.pb", storage.get_filename(1))
        self.assertEqual(
            "reports/TestReportID/manifest.pb", storage.get_filename(2)
        )

        # Check the manifest
        manifest = storage.get_message(2, StaticManifest)
        self.assertEqual("mock_report", manifest.name)
        self.assertEqual(2, manifest.num_messages)
        self.assertEqual(StaticManifest.DONE, manifest.server_status)

        # Check that the deltas we sent match messages in storage
        sent_messages = rs._report._master_queue._queue
        received_messages = [
            storage.get_message(0, ForwardMsg),
            storage.get_message(1, ForwardMsg),
        ]
        self.assertEqual(sent_messages, received_messages)
    finally:
        add_report_ctx(ctx=orig_ctx)
def get_app(self):
    """Build an app exposing UploadFileRequestHandler on both upload URL forms."""
    self.file_mgr = UploadedFileManager()

    # Each route gets its own kwargs dict, both pointing at the same manager.
    upload_patterns = ("/upload_file/(.*)/(.*)/([0-9]*)?", "/upload_file")
    routes = [
        (pattern, UploadFileRequestHandler, dict(file_mgr=self.file_mgr))
        for pattern in upload_patterns
    ]
    return tornado.web.Application(routes)
def test_passes_client_state_on_run_on_save(self, _):
    """With _run_on_save enabled, the source-changed handler must request a
    rerun using the session's current _client_state."""
    session = ReportSession(None, "", "", UploadedFileManager(), None)
    session._run_on_save = True
    session.request_rerun = MagicMock()

    session._on_source_file_changed()

    session.request_rerun.assert_called_once_with(session._client_state)
def test_clear_cache_resets_session_state(self, _1):
    """handle_clear_cache_request must drop previously-set session-state keys."""
    rs = AppSession(
        None, SessionData("", ""), UploadedFileManager(), None, MagicMock()
    )
    rs._session_state["foo"] = "bar"
    rs.handle_clear_cache_request()
    # assertNotIn shows the container on failure; assertTrue(... not in ...)
    # only reports "False is not true".
    self.assertNotIn("foo", rs._session_state)
def __init__(self, script_name):
    """Build a ScriptRunner for test_data/<script_name> that records its output.

    Parameters
    ----------
    script_name : str
        File name of a script inside the "test_data" directory that sits
        next to this module.
    """
    # Every forward message the script enqueues lands in this queue.
    self.forward_msg_queue = ForwardMsgQueue()

    def _enqueue(msg):
        self.forward_msg_queue.enqueue(msg)
        # Give the runner a chance to act on a pending control request
        # after each enqueued message.
        self.maybe_handle_execution_control_request()

    self.script_request_queue = ScriptRequestQueue()

    test_data_dir = os.path.join(os.path.dirname(__file__), "test_data")
    script_path = os.path.join(test_data_dir, script_name)

    super().__init__(
        session_id="test session id",
        session_data=SessionData(script_path, "test command line"),
        enqueue_forward_msg=_enqueue,
        client_state=ClientState(),
        session_state=SessionState(),
        request_queue=self.script_request_queue,
        uploaded_file_mgr=UploadedFileManager(),
    )

    # Uncaught exceptions raised on the run thread.
    self.script_thread_exceptions = []
    # Every emitted ScriptRunnerEvent, with its keyword payload in event_data.
    self.events = []
    self.event_data = []

    def _on_event(event, **kwargs):
        self.events.append(event)
        self.event_data.append(kwargs)

    # weak=False: hold a strong reference, because nothing else keeps this
    # local function alive once __init__ returns.
    self.on_event.connect(_on_event, weak=False)
def test_enqueue_without_tracer(self, _1, _2, patched_config):
    """With runner.installTracer off, enqueue() itself must call
    maybe_handle_execution_control_request exactly once."""
    option_values = {
        # Just to avoid starting the watcher for no reason.
        "server.runOnSave": False,
        "client.displayEnabled": True,
        "runner.installTracer": False,
    }

    def get_option(name):
        if name not in option_values:
            raise RuntimeError("Unexpected argument to get_option: %s" % name)
        return option_values[name]

    patched_config.get_option.side_effect = get_option

    rs = ReportSession(None, "", "", UploadedFileManager())
    runner = MagicMock()
    runner._install_tracer = ScriptRunner._install_tracer
    rs._scriptrunner = runner

    rs.enqueue(MagicMock())

    # The only caller should be enqueue(), and only once.
    runner.maybe_handle_execution_control_request.assert_called_once()
def test_clear_cache_all_caches(
    self, clear_singleton_cache, clear_memo_cache, clear_legacy_cache
):
    """A clear-cache request must invoke each of the three injected cache
    clearers exactly once."""
    session = ReportSession(MagicMock(), "", "", UploadedFileManager(), None)
    session.handle_clear_cache_request()

    for cache_clearer in (
        clear_singleton_cache,
        clear_memo_cache,
        clear_legacy_cache,
    ):
        cache_clearer.assert_called_once()
def test_get_deploy_params_with_no_git(self, _1):
    """get_deploy_params() must return None when PATH is empty (no git found)."""
    import os
    from unittest import mock

    # Empty PATH so no git executable can be located. patch.dict restores the
    # previous environment on exit. The old direct assignment to os.environ
    # leaked an empty PATH into every later test in the process.
    with mock.patch.dict(os.environ, {"PATH": ""}):
        rs = ReportSession(None, report_session.__file__, "", UploadedFileManager())
        self.assertIsNone(rs.get_deploy_params())
def test_enqueue_new_session_message(self, _1, _2, patched_config):
    """SCRIPT_STARTED must enqueue a new_session message plus a session-state
    message."""

    def get_option(name):
        if name == "server.runOnSave":
            # Just to avoid starting the watcher for no reason.
            return False
        return config.get_option(name)

    patched_config.get_option.side_effect = get_option
    patched_config.get_options_for_section.side_effect = (
        _mock_get_options_for_section()
    )

    # Create a AppSession with some mocked bits
    rs = AppSession(
        self.io_loop,
        SessionData("mock_report.py", ""),
        UploadedFileManager(),
        lambda: None,
        MagicMock(),
    )

    orig_ctx = get_script_run_ctx()
    ctx = ScriptRunContext(
        "TestSessionID", rs._session_data.enqueue, "", None, None
    )
    add_script_run_ctx(ctx=ctx)

    # Restore the original ScriptRunContext even if an assertion fails, so
    # this test cannot leak its context into later tests.
    try:
        rs._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STARTED)

        sent_messages = rs._session_data._browser_queue._queue
        self.assertEqual(len(sent_messages), 2)  # NewApp and SessionState messages

        # Note that we're purposefully not very thoroughly testing new_session
        # fields below to avoid getting to the point where we're just
        # duplicating code in tests.
        new_session_msg = sent_messages[0].new_session
        self.assertEqual("mock_scriptrun_id", new_session_msg.script_run_id)

        self.assertTrue(new_session_msg.HasField("config"))
        self.assertEqual(
            new_session_msg.config.allow_run_on_save,
            config.get_option("server.allowRunOnSave"),
        )

        self.assertTrue(new_session_msg.HasField("custom_theme"))
        self.assertEqual(new_session_msg.custom_theme.text_color, "black")

        init_msg = new_session_msg.initialize
        self.assertTrue(init_msg.HasField("user_info"))
    finally:
        add_script_run_ctx(ctx=orig_ctx)
def test_set_page_config_immutable(self):
    """A second page_config_changed message on the same context must raise."""
    ctx = ReportContext(
        "TestSessionID", lambda msg: None, "", Widgets(), UploadedFileManager()
    )

    page_config_msg = ForwardMsg()
    page_config_msg.page_config_changed.title = "foo"

    ctx.enqueue(page_config_msg)
    with self.assertRaises(StreamlitAPIException):
        ctx.enqueue(page_config_msg)
def __init__(self, ioloop: IOLoop, script_path: str, command_line: Optional[str]):
    """Create the Server singleton. It is not started here.

    Raises
    ------
    RuntimeError
        If a Server has already been created.
    """
    if Server._singleton is not None:
        raise RuntimeError(
            "Server already initialized. Use .get_current() instead")
    Server._singleton = self

    _set_tornado_log_levels()

    self._ioloop = ioloop
    self._script_path = script_path
    self._command_line = command_line

    # Mapping of ReportSession.id -> SessionInfo.
    self._session_info_by_id: Dict[str, SessionInfo] = {}

    self._must_stop = tornado.locks.Event()
    self._state = State.INITIAL
    self._message_cache = ForwardMsgCache()

    file_mgr = UploadedFileManager()
    self._uploaded_file_mgr = file_mgr
    file_mgr.on_files_updated.connect(self.on_files_updated)

    self._report: Optional[Report] = None
    self._preheated_session_id: Optional[str] = None

    self._has_connection = tornado.locks.Condition()
    self._need_send_data = tornado.locks.Event()

    # Register each stats source with a single StatsManager, in a fixed order.
    stats_mgr = StatsManager()
    self._stats_mgr = stats_mgr
    stats_mgr.register_provider(get_memo_stats_provider())
    stats_mgr.register_provider(get_singleton_stats_provider())
    stats_mgr.register_provider(_mem_caches)
    stats_mgr.register_provider(self._message_cache)
    stats_mgr.register_provider(in_memory_file_manager)
    stats_mgr.register_provider(file_mgr)
    stats_mgr.register_provider(
        SessionStateStatProvider(self._session_info_by_id))
def test_set_page_config_reset(self):
    """After ctx.reset(), a page_config_changed message may be sent again."""
    ctx = ReportContext(
        "TestSessionID", lambda msg: None, "", Widgets(), UploadedFileManager()
    )

    config_msg = ForwardMsg()
    config_msg.page_config_changed.title = "foo"

    ctx.enqueue(config_msg)
    ctx.reset()

    try:
        ctx.enqueue(config_msg)
    except StreamlitAPIException:
        self.fail("set_page_config should have succeeded after reset!")
def test_set_page_config_first(self):
    """A page_config_changed message after another element must raise."""
    ctx = ReportContext(
        "TestSessionID", lambda msg: None, "", Widgets(), UploadedFileManager()
    )

    markdown_msg = ForwardMsg()
    markdown_msg.delta.new_element.markdown.body = "foo"

    page_config_msg = ForwardMsg()
    page_config_msg.page_config_changed.title = "foo"

    ctx.enqueue(markdown_msg)
    with self.assertRaises(StreamlitAPIException):
        ctx.enqueue(page_config_msg)
def setUp(self, override_root=True):
    """Create a fresh ReportQueue.

    Unless override_root is False, also install a ReportContext on the
    current thread whose enqueue feeds that queue.
    """
    self.report_queue = ReportQueue()
    self.override_root = override_root
    # Stays None when no context is installed.
    self.orig_report_ctx = None

    if not self.override_root:
        return

    # Keep the previous context; presumably restored in tearDown (not shown).
    self.orig_report_ctx = get_report_ctx()
    test_ctx = ReportContext(
        session_id="test session id",
        enqueue=self.report_queue.enqueue,
        query_string="",
        widgets=Widgets(),
        uploaded_file_mgr=UploadedFileManager(),
    )
    add_report_ctx(threading.current_thread(), test_ctx)
def setUp(self, override_root=True):
    """Create a ForwardMsgQueue and a ScriptRunContext feeding it.

    The context is installed on the current thread only when override_root
    is true. A FakeAppSession is always created.
    """
    self.forward_msg_queue = ForwardMsgQueue()
    self.override_root = override_root
    self.orig_report_ctx = None

    script_run_ctx = ScriptRunContext(
        session_id="test session id",
        enqueue=self.forward_msg_queue.enqueue,
        query_string="",
        session_state=SessionState(),
        uploaded_file_mgr=UploadedFileManager(),
    )
    self.new_script_run_ctx = script_run_ctx

    if self.override_root:
        # Remember whatever context was active so it can be put back later.
        self.orig_report_ctx = get_script_run_ctx()
        add_script_run_ctx(threading.current_thread(), script_run_ctx)

    self.app_session = FakeAppSession()
class UploadFileRequestHandlerInvalidSessionTest(
        tornado.testing.AsyncHTTPTestCase):
    """Tests the /upload_file endpoint when the session lookup finds nothing."""

    def get_app(self):
        self.file_mgr = UploadedFileManager()
        # A session lookup that never finds a session.
        self._get_session_info = lambda x: None

        upload_handler_kwargs = dict(
            file_mgr=self.file_mgr,
            get_session_info=self._get_session_info,
        )
        return tornado.web.Application(
            [(UPLOAD_FILE_ROUTE, UploadFileRequestHandler, upload_handler_kwargs)]
        )

    def _upload_files(self, params):
        """POST `params` to /upload_file as a multipart/form-data request."""
        # requests composes the multipart body, which Tornado has no helper
        # for. self.fetch() then sends it to the test server.
        prepared = requests.Request(
            method="POST", url=self.get_url("/upload_file"), files=params
        ).prepare()
        return self.fetch(
            "/upload_file",
            method=prepared.method,
            headers=prepared.headers,
            body=prepared.body,
        )

    def test_upload_one_file(self):
        """Upload should fail if the sessionId doesn't exist."""
        mock_file = MockFile("filename", b"123")
        params = {
            mock_file.name: mock_file.data,
            "sessionId": (None, "mockSessionId"),
            "widgetId": (None, "mockWidgetId"),
        }

        response = self._upload_files(params)

        self.assertEqual(400, response.code)
        self.assertIn("Invalid session_id: 'mockSessionId'", response.reason)
        self.assertEqual(
            self.file_mgr.get_all_files("mockSessionId", "mockWidgetId"), []
        )
def __init__(self, script_name: str):
    """Build a ScriptRunner for test_data/<script_name> that records its events.

    ENQUEUE_FORWARD_MSG events are also forwarded into self.forward_msg_queue.
    """
    self.forward_msg_queue = ForwardMsgQueue()

    test_data_dir = os.path.join(os.path.dirname(__file__), "test_data")
    main_script_path = os.path.join(test_data_dir, script_name)

    super().__init__(
        session_id="test session id",
        session_data=SessionData(main_script_path, "test command line"),
        client_state=ClientState(),
        session_state=SessionState(),
        uploaded_file_mgr=UploadedFileManager(),
        initial_rerun_data=RerunData(),
    )

    # Uncaught exceptions raised on the run thread.
    self.script_thread_exceptions: List[BaseException] = []

    # Every emitted ScriptRunnerEvent, with its keyword payload in event_data.
    self.events: List[ScriptRunnerEvent] = []
    self.event_data: List[Any] = []

    def record_event(sender: Optional[ScriptRunner], event: ScriptRunnerEvent, **kwargs) -> None:
        # Only events from this runner (or sender-less ones) are expected.
        assert (sender is None or sender == self), "Unexpected ScriptRunnerEvent sender!"

        self.events.append(event)
        self.event_data.append(kwargs)

        if event != ScriptRunnerEvent.ENQUEUE_FORWARD_MSG:
            return
        # Mirror forward messages into our own queue for inspection.
        self.forward_msg_queue.enqueue(kwargs["forward_msg"])

    # weak=False: the signal must keep this local function alive.
    self.on_event.connect(record_event, weak=False)
def test_enqueue_with_tracer(self, _1, _2, patched_config, _4):
    """Make sure there is no lock contention when tracer is on.

    When the tracer is set up, maybe_handle_execution_control_request should
    run only once. A past bug called it twice, once from the tracer and once
    from the enqueue function, which caused lock contention.
    """
    fake_options = {
        # Just to avoid starting the watcher for no reason.
        "server.runOnSave": False,
        "client.displayEnabled": True,
        "runner.installTracer": True,
    }

    def get_option(name):
        if name in fake_options:
            return fake_options[name]
        raise RuntimeError("Unexpected argument to get_option: %s" % name)

    patched_config.get_option.side_effect = get_option

    session = AppSession(
        None, SessionData("", ""), UploadedFileManager(), lambda: None, MagicMock()
    )
    runner = MagicMock()
    session._scriptrunner = runner

    session.enqueue(MagicMock())

    # Outside tests, the installed tracer would make this call once. Here the
    # script runner is a MagicMock, so no tracer actually runs. With
    # installTracer on, enqueue() must skip the call itself, so the correct
    # outcome is zero calls. A call here points to a bug in enqueue().
    runner.maybe_handle_execution_control_request.assert_not_called()
def test_disallow_set_page_config_twice(self):
    """After the script starts, a second page_config_changed message must
    raise, even with a different title."""
    ctx = ScriptRunContext(
        "TestSessionID",
        lambda msg: None,
        "",
        SessionState(),
        UploadedFileManager(),
    )
    ctx.on_script_start()

    first_config = ForwardMsg()
    first_config.page_config_changed.title = "foo"
    ctx.enqueue(first_config)

    with self.assertRaises(StreamlitAPIException):
        second_config = ForwardMsg()
        second_config.page_config_changed.title = "bar"
        ctx.enqueue(second_config)
def test_set_page_config_first(self):
    """Once the script has started, page_config_changed after any other
    element must raise."""
    ctx = ScriptRunContext(
        "TestSessionID",
        lambda msg: None,
        "",
        SessionState(),
        UploadedFileManager(),
    )
    ctx.on_script_start()

    markdown_msg = ForwardMsg()
    markdown_msg.delta.new_element.markdown.body = "foo"

    page_config_msg = ForwardMsg()
    page_config_msg.page_config_changed.title = "foo"

    ctx.enqueue(markdown_msg)
    with self.assertRaises(StreamlitAPIException):
        ctx.enqueue(page_config_msg)
class Server:
    """Process-wide web server that owns all ReportSessions.

    Only one instance may exist. Use Server.get_current() to obtain it.
    """

    _singleton: Optional["Server"] = None

    @classmethod
    def get_current(cls) -> "Server":
        """
        Returns
        -------
        Server
            The singleton Server object.

        Raises
        ------
        RuntimeError
            If no Server has been created yet.
        """
        if Server._singleton is None:
            raise RuntimeError("Server has not been initialized yet")

        return Server._singleton

    def __init__(self, ioloop: IOLoop, script_path: str, command_line: Optional[str]):
        """Create the server. It won't be started yet."""
        if Server._singleton is not None:
            raise RuntimeError("Server already initialized. Use .get_current() instead")

        Server._singleton = self

        _set_tornado_log_levels()

        self._ioloop = ioloop
        self._script_path = script_path
        self._command_line = command_line

        # Mapping of ReportSession.id -> SessionInfo.
        self._session_info_by_id: Dict[str, SessionInfo] = {}

        # Set by stop(); polled by _loop_coroutine.
        self._must_stop = threading.Event()
        self._state = State.INITIAL
        self._message_cache = ForwardMsgCache()
        self._uploaded_file_mgr = UploadedFileManager()
        self._uploaded_file_mgr.on_files_updated.connect(self.on_files_updated)
        self._report: Optional[Report] = None
        self._preheated_session_id: Optional[str] = None

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def script_path(self) -> str:
        return self._script_path

    def get_session_by_id(self, session_id: str) -> Optional[ReportSession]:
        """Return the ReportSession corresponding to the given id, or None if
        no such session exists."""
        session_info = self._get_session_info(session_id)
        if session_info is None:
            return None

        return session_info.session

    def on_files_updated(self, session_id: str) -> None:
        """Event handler for UploadedFileManager.on_files_updated.

        Ensures that uploaded files from stale sessions get deleted.
        """
        session_info = self._get_session_info(session_id)
        if session_info is None:
            # If an uploaded file doesn't belong to an existing session,
            # remove it so it doesn't stick around forever.
            self._uploaded_file_mgr.remove_session_files(session_id)

    def _get_session_info(self, session_id: str) -> Optional[SessionInfo]:
        """Return the SessionInfo with the given id, or None if no such
        session exists.
        """
        return self._session_info_by_id.get(session_id, None)

    def start(self, on_started: Callable[["Server"], Any]) -> None:
        """Start the server.

        Parameters
        ----------
        on_started : callable
            A callback that will be called when the server's run-loop
            has started, and the server is ready to begin receiving clients.

        Raises
        ------
        RuntimeError
            If the server is not in the INITIAL state.
        """
        if self._state != State.INITIAL:
            raise RuntimeError("Server has already been started")

        LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        port = config.get_option("server.port")

        LOGGER.debug("Server started on port %s", port)

        self._ioloop.spawn_callback(self._loop_coroutine, on_started)

    def get_debug(self) -> Dict[str, Dict[str, Any]]:
        if self._report:
            return {"report": self._report.get_debug()}
        return {}

    def _create_app(self) -> tornado.web.Application:
        """Create our tornado web app."""
        base = config.get_option("server.baseUrlPath")

        routes = [
            (
                make_url_path_regex(base, "stream"),
                _BrowserWebSocketHandler,
                dict(server=self),
            ),
            (
                make_url_path_regex(base, "healthz"),
                HealthHandler,
                # is_ready_for_browser_connection is an async property, so
                # reading it yields a coroutine (presumably awaited by
                # HealthHandler).
                dict(callback=lambda: self.is_ready_for_browser_connection),
            ),
            (make_url_path_regex(base, "debugz"), DebugHandler, dict(server=self)),
            (make_url_path_regex(base, "metrics"), MetricsHandler),
            (
                make_url_path_regex(base, "message"),
                MessageCacheHandler,
                dict(cache=self._message_cache),
            ),
            (
                make_url_path_regex(
                    base,
                    UPLOAD_FILE_ROUTE,
                ),
                UploadFileRequestHandler,
                dict(
                    file_mgr=self._uploaded_file_mgr,
                    get_session_info=self._get_session_info,
                ),
            ),
            (
                make_url_path_regex(base, "assets/(.*)"),
                AssetsFileHandler,
                {"path": "%s/" % file_util.get_assets_dir()},
            ),
            (make_url_path_regex(base, "media/(.*)"), MediaFileHandler, {"path": ""}),
            (
                make_url_path_regex(base, "component/(.*)"),
                ComponentRequestHandler,
                dict(registry=ComponentRegistry.instance()),
            ),
        ]

        if config.get_option("server.scriptHealthCheckEnabled"):
            routes.extend(
                [
                    (
                        make_url_path_regex(base, "script-health-check"),
                        HealthHandler,
                        dict(callback=lambda: self.does_script_run_without_error()),
                    )
                ]
            )

        if config.get_option("global.developmentMode"):
            LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = file_util.get_static_dir()
            LOGGER.debug("Serving static content from %s", static_path)

            routes.extend(
                [
                    (
                        make_url_path_regex(base, "(.*)"),
                        StaticFileHandler,
                        {"path": "%s/" % static_path, "default_filename": "index.html"},
                    ),
                    (make_url_path_regex(base, trailing_slash=False), AddSlashHandler),
                ]
            )

        return tornado.web.Application(
            routes,  # type: ignore[arg-type]
            cookie_secret=config.get_option("server.cookieSecret"),
            xsrf_cookies=config.get_option("server.enableXsrfProtection"),
            **TORNADO_SETTINGS,  # type: ignore[arg-type]
        )

    def _set_state(self, new_state: State) -> None:
        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        LOGGER.debug("Server state: %s -> %s", self._state, new_state)
        self._state = new_state

    @property
    async def is_ready_for_browser_connection(self) -> Tuple[bool, str]:
        if self._state not in (State.INITIAL, State.STOPPING, State.STOPPED):
            return True, "ok"

        return False, "unavailable"

    async def does_script_run_without_error(self) -> Tuple[bool, str]:
        """Load and execute the app's script to verify it runs without an error.

        Returns
        -------
        (True, "ok") if the script completes without error, or (False, err_msg)
        if the script raises an exception.
        """
        session = ReportSession(
            ioloop=self._ioloop,
            script_path=self._script_path,
            command_line=self._command_line,
            uploaded_file_manager=self._uploaded_file_mgr,
        )

        try:
            session.request_rerun(None)

            # Poll until the session records a result or the timeout expires.
            now = time.perf_counter()
            while (
                SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state
                and (time.perf_counter() - now) < SCRIPT_RUN_CHECK_TIMEOUT
            ):
                await tornado.gen.sleep(0.1)

            if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state:
                return False, "timeout"

            ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY]
            msg = "ok" if ok else "error"

            return ok, msg
        finally:
            session.shutdown()

    @property
    def browser_is_connected(self) -> bool:
        return self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED

    @property
    def is_running_hello(self) -> bool:
        from streamlit.hello import hello

        return self._script_path == hello.__file__

    @tornado.gen.coroutine
    def _loop_coroutine(
        self, on_started: Optional[Callable[["Server"], Any]] = None
    ) -> Generator[Any, None, None]:
        """Run-loop: flush each connected session's browser queue until stopped."""
        try:
            if self._state == State.INITIAL:
                self._set_state(State.WAITING_FOR_FIRST_BROWSER)
            elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:
                pass
            else:
                raise RuntimeError("Bad server state at start: %s" % self._state)

            if on_started is not None:
                on_started(self)

            while not self._must_stop.is_set():

                if self._state == State.WAITING_FOR_FIRST_BROWSER:
                    pass

                elif self._state == State.ONE_OR_MORE_BROWSERS_CONNECTED:

                    # Shallow-clone our sessions into a list, so we can iterate
                    # over it and not worry about whether it's being changed
                    # outside this coroutine.
                    session_infos = list(self._session_info_by_id.values())

                    for session_info in session_infos:
                        if session_info.ws is None:
                            # Preheated.
                            continue
                        msg_list = session_info.session.flush_browser_queue()
                        for msg in msg_list:
                            try:
                                self._send_message(session_info, msg)
                            except tornado.websocket.WebSocketClosedError:
                                self._close_report_session(session_info.session.id)
                                # The socket is closed and the session has been
                                # shut down. Stop caching and writing its
                                # remaining queued messages.
                                break
                            yield
                        yield

                elif self._state == State.NO_BROWSERS_CONNECTED:
                    pass

                else:
                    # Break out of the thread loop if we encounter any other state.
                    break

                yield tornado.gen.sleep(0.01)

            # Shut down all ReportSessions
            for session_info in list(self._session_info_by_id.values()):
                session_info.session.shutdown()

            self._set_state(State.STOPPED)

        except Exception:
            # Can't just re-raise here because co-routines use Tornado
            # exceptions for control flow, which appears to swallow the reraised
            # exception.
            traceback.print_exc()
            LOGGER.info(
                """ Please report this bug at https://github.com/streamlit/streamlit/issues. """
            )

        finally:
            self._on_stopped()

    def _send_message(self, session_info: SessionInfo, msg: ForwardMsg) -> None:
        """Send a message to a client.

        If the client is likely to have already cached the message, we may
        instead send a "reference" message that contains only the hash of the
        message.

        Parameters
        ----------
        session_info : SessionInfo
            The SessionInfo associated with websocket
        msg : ForwardMsg
            The message to send to the client

        """
        msg.metadata.cacheable = is_cacheable_msg(msg)
        msg_to_send = msg
        if msg.metadata.cacheable:
            populate_hash_if_needed(msg)

            if self._message_cache.has_message_reference(
                msg, session_info.session, session_info.report_run_count
            ):

                # This session has probably cached this message. Send
                # a reference instead.
                LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
                msg_to_send = create_reference_msg(msg)

            # Cache the message so it can be referenced in the future.
            # If the message is already cached, this will reset its
            # age.
            LOGGER.debug("Caching message (hash=%s)", msg.hash)
            self._message_cache.add_message(
                msg, session_info.session, session_info.report_run_count
            )

        # If this was a `report_finished` message, we increment the
        # report_run_count for this session, and update the cache
        if (
            msg.WhichOneof("type") == "report_finished"
            and msg.report_finished == ForwardMsg.FINISHED_SUCCESSFULLY
        ):
            LOGGER.debug(
                "Report finished successfully; "
                "removing expired entries from MessageCache "
                "(max_age=%s)",
                config.get_option("global.maxCachedMessageAge"),
            )
            session_info.report_run_count += 1
            self._message_cache.remove_expired_session_entries(
                session_info.session, session_info.report_run_count
            )

        # Ship it off!
        if session_info.ws is not None:
            session_info.ws.write_message(
                serialize_forward_msg(msg_to_send), binary=True
            )

    def stop(self) -> None:
        click.secho("  Stopping...", fg="blue")
        self._set_state(State.STOPPING)
        self._must_stop.set()

    def _on_stopped(self) -> None:
        """Called when our runloop is exiting, to shut down the ioloop.
        This will end our process.

        (Tests can patch this method out, to prevent the test's ioloop
        from being shutdown.)
        """
        self._ioloop.stop()

    def add_preheated_report_session(self) -> None:
        """Register a fake browser with the server and run the script.

        This is used to start running the user's script even before the first
        browser connects.
        """
        session = self._create_or_reuse_report_session(ws=None)
        session.handle_rerun_script_request(is_preheat=True)

    def _create_or_reuse_report_session(
        self, ws: Optional[WebSocketHandler]
    ) -> ReportSession:
        """Register a connected browser with the server.

        Parameters
        ----------
        ws : _BrowserWebSocketHandler or None
            The newly-connected websocket handler or None if preheated
            connection.

        Returns
        -------
        ReportSession
            The newly-created ReportSession for this browser connection.

        """
        if self._preheated_session_id is not None:
            # Hand the preheated session to the first real websocket.
            assert len(self._session_info_by_id) == 1
            assert ws is not None

            session_id = self._preheated_session_id
            self._preheated_session_id = None

            session_info = self._session_info_by_id[session_id]
            session_info.ws = ws
            session = session_info.session

            LOGGER.debug(
                "Reused preheated session for ws %s. Session ID: %s", id(ws), session_id
            )

        else:
            session = ReportSession(
                ioloop=self._ioloop,
                script_path=self._script_path,
                command_line=self._command_line,
                uploaded_file_manager=self._uploaded_file_mgr,
            )

            LOGGER.debug(
                "Created new session for ws %s. Session ID: %s", id(ws), session.id
            )

            assert session.id not in self._session_info_by_id, (
                "session.id '%s' registered multiple times!" % session.id
            )

        self._session_info_by_id[session.id] = SessionInfo(ws, session)

        if ws is None:
            self._preheated_session_id = session.id
        else:
            self._set_state(State.ONE_OR_MORE_BROWSERS_CONNECTED)

        return session

    def _close_report_session(self, session_id: str) -> None:
        """Shutdown and remove a ReportSession.

        This function may be called multiple times for the same session,
        which is not an error. (Subsequent calls just no-op.)

        Parameters
        ----------
        session_id : str
            The ReportSession's id string.
        """
        if session_id in self._session_info_by_id:
            session_info = self._session_info_by_id[session_id]
            del self._session_info_by_id[session_id]
            session_info.session.shutdown()

        if len(self._session_info_by_id) == 0:
            self._set_state(State.NO_BROWSERS_CONNECTED)
class UploadedFileManagerTest(unittest.TestCase):
    """Unit tests for UploadedFileManager add/get/remove/replace behavior."""

    def setUp(self):
        self.mgr = UploadedFileManager()
        # One entry per on_files_updated emission.
        self.filemgr_events = []
        self.mgr.on_files_updated.connect(self._on_files_updated)

    def _on_files_updated(self, file_list, **kwargs):
        # Record the signal's positional argument for each emission.
        self.filemgr_events.append(file_list)

    def test_add_file(self):
        mgr = self.mgr
        self.assertIsNone(mgr.get_files("non-report", "non-widget"))

        mgr.add_files("session", "widget", [file1])
        self.assertEqual([file1], mgr.get_files("session", "widget"))
        self.assertEqual(1, len(self.filemgr_events))

        # Adding again under the same session/widget appends to the list.
        mgr.add_files("session", "widget", [file2])
        self.assertEqual([file1, file2], mgr.get_files("session", "widget"))
        self.assertEqual(2, len(self.filemgr_events))

    def test_remove_file(self):
        mgr = self.mgr

        # Removing from an unknown session/widget should not error.
        mgr.remove_files("non-report", "non-widget")

        mgr.add_files("session", "widget", [file1])
        mgr.remove_file("session", "widget", file1.id)
        self.assertEqual([], mgr.get_files("session", "widget"))

        # Removing an already-removed file leaves the list empty.
        mgr.remove_file("session", "widget", file1.id)
        self.assertEqual([], mgr.get_files("session", "widget"))

        mgr.add_files("session", "widget", [file1])
        mgr.add_files("session", "widget", [file2])
        mgr.remove_file("session", "widget", file1.id)
        self.assertEqual([file2], mgr.get_files("session", "widget"))

    def test_remove_widget_files(self):
        mgr = self.mgr

        # Removing an unknown session should not error.
        mgr.remove_session_files("non-report")

        # Same widget ID under two different sessions.
        mgr.add_files("session1", "widget", [file1])
        mgr.add_files("session2", "widget", [file1])

        mgr.remove_files("session1", "widget")
        self.assertIsNone(mgr.get_files("session1", "widget"))
        self.assertEqual([file1], mgr.get_files("session2", "widget"))

    def test_remove_session_files(self):
        mgr = self.mgr

        # Removing an unknown session should not error.
        mgr.remove_session_files("non-report")

        # Two widgets under session1 and one under session2.
        mgr.add_files("session1", "widget1", [file1])
        mgr.add_files("session1", "widget2", [file1])
        mgr.add_files("session2", "widget", [file1])

        mgr.remove_session_files("session1")
        self.assertIsNone(mgr.get_files("session1", "widget1"))
        self.assertIsNone(mgr.get_files("session1", "widget2"))
        self.assertEqual([file1], mgr.get_files("session2", "widget"))

    def test_replace_widget_files(self):
        mgr = self.mgr
        mgr.add_files("session1", "widget", [file1])
        mgr.replace_files("session1", "widget", [file2])

        self.assertEqual(1, len(mgr.get_files("session1", "widget")))
        self.assertEqual([file2], mgr.get_files("session1", "widget"))
def test_enqueue_new_session_message(self, patched_config):
    """The SCRIPT_STARTED event should enqueue a 'new_session' message."""

    def get_option(name):
        if name == "server.runOnSave":
            # Just to avoid starting the watcher for no reason.
            return False

        return config.get_option(name)

    patched_config.get_option.side_effect = get_option
    patched_config.get_options_for_section.side_effect = (
        _mock_get_options_for_section()
    )

    # Create a AppSession with some mocked bits
    session = AppSession(
        ioloop=self.io_loop,
        session_data=SessionData("mock_report.py", ""),
        uploaded_file_manager=UploadedFileManager(),
        message_enqueued_callback=lambda: None,
        local_sources_watcher=MagicMock(),
    )

    orig_ctx = get_script_run_ctx()
    ctx = ScriptRunContext(
        session_id="TestSessionID",
        enqueue=session._session_data.enqueue,
        query_string="",
        session_state=MagicMock(),
        uploaded_file_mgr=MagicMock(),
    )
    add_script_run_ctx(ctx=ctx)

    # Restore the original ScriptRunContext even if an assertion fails, so
    # this test cannot leak its context into later tests.
    try:
        # Send a mock SCRIPT_STARTED event.
        session._on_scriptrunner_event(
            sender=MagicMock(), event=ScriptRunnerEvent.SCRIPT_STARTED
        )

        sent_messages = session._session_data._browser_queue._queue
        self.assertEqual(2, len(sent_messages))  # NewApp and SessionState messages

        # Note that we're purposefully not very thoroughly testing new_session
        # fields below to avoid getting to the point where we're just
        # duplicating code in tests.
        new_session_msg = sent_messages[0].new_session
        self.assertEqual("mock_scriptrun_id", new_session_msg.script_run_id)

        self.assertTrue(new_session_msg.HasField("config"))
        self.assertEqual(
            config.get_option("server.allowRunOnSave"),
            new_session_msg.config.allow_run_on_save,
        )

        self.assertTrue(new_session_msg.HasField("custom_theme"))
        self.assertEqual("black", new_session_msg.custom_theme.text_color)

        init_msg = new_session_msg.initialize
        self.assertTrue(init_msg.HasField("user_info"))
    finally:
        add_script_run_ctx(ctx=orig_ctx)
class UploadFileRequestHandlerTest(tornado.testing.AsyncHTTPTestCase):
    """Exercises POST requests against the /upload_file endpoint."""

    def get_app(self):
        self.file_mgr = UploadedFileManager()
        handler_kwargs = dict(file_mgr=self.file_mgr)
        routes = [("/upload_file", UploadFileRequestHandler, handler_kwargs)]
        return tornado.web.Application(routes)

    def _upload_files(self, params):
        # requests.Request does the fiddly multipart/form-data encoding for
        # us (Tornado ships no helper for building such bodies); self.fetch()
        # then delivers the prepared request to the in-process test server.
        prepared = requests.Request(
            method="POST",
            url=self.get_url("/upload_file"),
            files=params,
        ).prepare()
        return self.fetch(
            "/upload_file",
            method=prepared.method,
            headers=prepared.headers,
            body=prepared.body,
        )

    def test_upload_one_file(self):
        """Uploading a file should populate our file_mgr."""
        upload = UploadedFile("image.png", b"123")
        params = {upload.name: upload}
        params["sessionId"] = (None, "fooReport")
        params["widgetId"] = (None, "barWidget")

        response = self._upload_files(params)

        self.assertEqual(200, response.code)
        stored = self.file_mgr.get_files("fooReport", "barWidget")
        self.assertEqual([upload], stored)

    def test_upload_multiple_files(self):
        """Every file in a multi-file upload should land in file_mgr."""
        uploads = [
            UploadedFile("image%d.png" % index, payload)
            for index, payload in enumerate((b"123", b"456", b"789"), start=1)
        ]
        params = {upload.name: upload for upload in uploads}
        params["sessionId"] = (None, "fooReport")
        params["widgetId"] = (None, "barWidget")

        response = self._upload_files(params)

        self.assertEqual(200, response.code)
        stored = self.file_mgr.get_files("fooReport", "barWidget")
        self.assertEqual(
            sorted(uploads, key=_get_filename),
            sorted(stored, key=_get_filename),
        )

    def test_missing_params(self):
        """Missing params in the body should fail with 400 status."""
        # "widgetId" is deliberately left out of the form.
        params = {
            "image.png": ("image.png", b"1234"),
            "sessionId": (None, "fooReport"),
        }

        response = self._upload_files(params)

        self.assertEqual(400, response.code)
        self.assertIn("Missing 'widgetId'", response.reason)

    def test_missing_file(self):
        """Missing file should fail with 400 status."""
        # Only form fields are sent; there is no file part at all.
        params = {
            "sessionId": (None, "fooReport"),
            "widgetId": (None, "barWidget"),
        }

        response = self._upload_files(params)

        self.assertEqual(400, response.code)
        self.assertIn("Expected at least 1 file, but got 0", response.reason)
class UploadFileRequestHandlerTest(tornado.testing.AsyncHTTPTestCase):
    """Exercises the /upload_file endpoint with a session lookup that always succeeds."""

    def get_app(self):
        self.file_mgr = UploadedFileManager()
        # Every session lookup returns a truthy value, so no request is
        # rejected for belonging to an unknown session.
        self._get_session_info = lambda x: True
        handler_kwargs = dict(
            file_mgr=self.file_mgr,
            get_session_info=self._get_session_info,
        )
        routes = [(UPLOAD_FILE_ROUTE, UploadFileRequestHandler, handler_kwargs)]
        return tornado.web.Application(routes)

    def _upload_files(self, params):
        # requests.Request does the fiddly multipart/form-data encoding for
        # us (Tornado ships no helper for building such bodies); self.fetch()
        # then delivers the prepared request to the in-process test server.
        prepared = requests.Request(
            method="POST",
            url=self.get_url("/upload_file"),
            files=params,
        ).prepare()
        return self.fetch(
            "/upload_file",
            method=prepared.method,
            headers=prepared.headers,
            body=prepared.body,
        )

    def _seed_two_sessions(self):
        """Store one file under session1/widget1 and one under session2/widget2."""
        first = UploadedFileRec("1234", "name", "type", b"1234")
        second = UploadedFileRec("4567", "name", "type", b"1234")
        self.file_mgr.add_files("session1", "widget1", [first])
        self.file_mgr.add_files("session2", "widget2", [second])

    def test_upload_one_file(self):
        """A single uploaded file should be stored in file_mgr."""
        mock_file = MockFile("filename", b"123")
        params = {mock_file.name: mock_file.data}
        params["sessionId"] = (None, "fooReport")
        params["widgetId"] = (None, "barWidget")
        params["totalFiles"] = (None, "1")

        response = self._upload_files(params)

        self.assertEqual(200, response.code)
        expected = [(mock_file.name, mock_file.data)]
        stored = self.file_mgr.get_files("fooReport", "barWidget")
        self.assertEqual(expected, [(rec.name, rec.data) for rec in stored])

    def test_upload_multiple_files(self):
        """Each file of a multi-file upload should be stored in file_mgr."""
        mock_files = [
            MockFile("file1", b"123"),
            MockFile("file2", b"456"),
            MockFile("file3", b"789"),
        ]
        params = {mock_file.name: mock_file.data for mock_file in mock_files}
        params["sessionId"] = (None, "fooReport")
        params["widgetId"] = (None, "barWidget")
        params["totalFiles"] = (None, "1")

        response = self._upload_files(params)

        self.assertEqual(200, response.code)
        stored = self.file_mgr.get_files("fooReport", "barWidget")
        self.assertEqual(
            sorted(mock_files),
            sorted((rec.name, rec.data) for rec in stored),
        )

    def test_upload_missing_params(self):
        """Omitting a required form field should fail with 400 status."""
        # "widgetId" is deliberately left out of the form.
        params = {
            "image.png": ("image.png", b"1234"),
            "sessionId": (None, "fooReport"),
            "totalFiles": (None, "1"),
        }

        response = self._upload_files(params)

        self.assertEqual(400, response.code)
        self.assertIn("Missing 'widgetId'", response.reason)

    def test_upload_missing_file(self):
        """A form with no file part should fail with 400 status."""
        params = {
            "sessionId": (None, "fooReport"),
            "widgetId": (None, "barWidget"),
            "totalFiles": (None, "1"),
        }

        response = self._upload_files(params)

        self.assertEqual(400, response.code)
        self.assertIn("Expected at least 1 file, but got 0", response.reason)

    def test_delete_file(self):
        """DELETE on a stored file removes it without touching other sessions."""
        self._seed_two_sessions()

        response = self.fetch("/upload_file/session1/widget1/1234", method="DELETE")

        self.assertEqual(200, response.code)
        self.assertEqual(0, len(self.file_mgr.get_files("session1", "widget1")))
        self.assertGreater(len(self.file_mgr.get_files("session2", "widget2")), 0)

    def test_delete_file_across_sessions(self):
        """DELETE whose session/widget/file params don't match fails with 404."""
        self._seed_two_sessions()

        response = self.fetch("/upload_file/session2/widget1/1234", method="DELETE")

        self.assertEqual(404, response.code)
        self.assertGreater(len(self.file_mgr.get_files("session1", "widget1")), 0)
        self.assertGreater(len(self.file_mgr.get_files("session2", "widget2")), 0)

    @parameterized.expand([
        (None, "widget_id", "123"),
        ("session_id", None, "123"),
        ("session_id", "widget_id", None),
    ])
    def test_delete_missing_param(self, session_id, widget_id, file_id):
        """A DELETE path with any missing segment should fail with 404 status."""
        path = f"/upload_file/{session_id}/{widget_id}/{file_id}"
        response = self.fetch(path, method="DELETE")
        self.assertEqual(404, response.code)
class UploadFileRequestHandlerInvalidSessionTest(
        tornado.testing.AsyncHTTPTestCase):
    """Tests the /upload_file endpoint when the session lookup fails.

    ``get_session_info`` always returns None here, so every request refers to
    an unknown session and should be rejected without touching file_mgr.
    """

    def get_app(self):
        self.file_mgr = UploadedFileManager()
        # Every session lookup yields None, i.e. "no such session".
        self._get_session_info = lambda x: None
        return tornado.web.Application([
            (
                UPLOAD_FILE_ROUTE,
                UploadFileRequestHandler,
                dict(
                    file_mgr=self.file_mgr,
                    get_session_info=self._get_session_info,
                ),
            ),
        ])

    def _upload_files(self, params):
        # We use requests.Request to construct our multipart/form-data request
        # here, because they are absurdly fiddly to compose, and Tornado
        # doesn't include a utility for building them. We then use self.fetch()
        # to actually send the request to the test server.
        req = requests.Request(method="POST",
                               url=self.get_url("/upload_file"),
                               files=params).prepare()
        return self.fetch("/upload_file",
                          method=req.method,
                          headers=req.headers,
                          body=req.body)

    def test_upload_one_file(self):
        """Uploading a file for an unknown session should fail with 400 and store nothing."""
        file = MockFile("filename", b"123")
        params = {
            file.name: file.data,
            "sessionId": (None, "fooReport"),
            "widgetId": (None, "barWidget"),
            "totalFiles": (None, "1"),
        }
        response = self._upload_files(params)
        self.assertEqual(400, response.code)
        self.assertIsNone(self.file_mgr.get_files("fooReport", "barWidget"))

    def test_upload_multiple_files(self):
        """Uploading several files for an unknown session should fail with 400 and store nothing."""
        file_1 = MockFile("file1", b"123")
        file_2 = MockFile("file2", b"456")
        file_3 = MockFile("file3", b"789")
        params = {
            file_1.name: file_1.data,
            file_2.name: file_2.data,
            file_3.name: file_3.data,
            "sessionId": (None, "fooReport"),
            "widgetId": (None, "barWidget"),
            "totalFiles": (None, "1"),
        }
        response = self._upload_files(params)
        self.assertEqual(400, response.code)
        self.assertIsNone(self.file_mgr.get_files("fooReport", "barWidget"))

    def test_delete_file(self):
        """Deleting a file for an unknown session should fail with 404 status."""
        file1 = UploadedFileRec("1234", "name", "type", b"1234")
        file2 = UploadedFileRec("4567", "name", "type", b"1234")
        self.file_mgr.add_files("session1", "widget1", [file1])
        self.file_mgr.add_files("session2", "widget2", [file2])
        response = self.fetch("/upload_file/session1/widget1/1234", method="DELETE")
        self.assertEqual(404, response.code)
class UploadedFileManagerTest(unittest.TestCase):
    """Unit tests for UploadedFileManager's add/get/remove bookkeeping."""

    def setUp(self):
        self.mgr = UploadedFileManager()
        # Collects the positional payload of each on_files_updated
        # notification, so tests can count how many were emitted.
        self.filemgr_events = []
        self.mgr.on_files_updated.connect(self._on_files_updated)

    def _on_files_updated(self, file_list, **kwargs):
        self.filemgr_events.append(file_list)

    def test_added_file_id(self):
        """An added file should have a unique ID."""
        f1 = self.mgr.add_file("session", "widget", FILE_1)
        f2 = self.mgr.add_file("session", "widget", FILE_1)
        self.assertNotEqual(FILE_1.id, f1.id)
        self.assertNotEqual(f1.id, f2.id)

    def test_added_file_properties(self):
        """An added file should maintain all its source properties except its ID."""
        added = self.mgr.add_file("session", "widget", FILE_1)
        self.assertNotEqual(added.id, FILE_1.id)
        self.assertEqual(added.name, FILE_1.name)
        self.assertEqual(added.type, FILE_1.type)
        self.assertEqual(added.data, FILE_1.data)

    def test_retrieve_added_file(self):
        """After adding a file to the mgr, we should be able to get it back."""
        self.assertEqual([], self.mgr.get_all_files("non-report", "non-widget"))

        file_1 = self.mgr.add_file("session", "widget", FILE_1)
        self.assertEqual([file_1], self.mgr.get_all_files("session", "widget"))
        self.assertEqual([file_1], self.mgr.get_files("session", "widget", [file_1.id]))
        self.assertEqual(len(self.filemgr_events), 1)

        # Add another file
        file_2 = self.mgr.add_file("session", "widget", FILE_2)
        self.assertEqual([file_1, file_2], self.mgr.get_all_files("session", "widget"))
        self.assertEqual([file_1], self.mgr.get_files("session", "widget", [file_1.id]))
        self.assertEqual([file_2], self.mgr.get_files("session", "widget", [file_2.id]))
        self.assertEqual(len(self.filemgr_events), 2)

    def test_remove_file(self):
        # This should not error.
        self.mgr.remove_files("non-report", "non-widget")

        f1 = self.mgr.add_file("session", "widget", FILE_1)
        self.mgr.remove_file("session", "widget", f1.id)
        self.assertEqual([], self.mgr.get_all_files("session", "widget"))

        # Remove the file again. It doesn't exist, but this isn't an error.
        self.mgr.remove_file("session", "widget", f1.id)
        self.assertEqual([], self.mgr.get_all_files("session", "widget"))

        f1 = self.mgr.add_file("session", "widget", FILE_1)
        f2 = self.mgr.add_file("session", "widget", FILE_2)
        self.mgr.remove_file("session", "widget", f1.id)
        self.assertEqual([f2], self.mgr.get_all_files("session", "widget"))

    def test_remove_widget_files(self):
        # Removing the files of an untracked session/widget should not error.
        self.mgr.remove_files("non-report", "non-widget")

        # Add two files with different session IDs, but the same widget ID.
        self.mgr.add_file("session1", "widget", FILE_1)
        f2 = self.mgr.add_file("session2", "widget", FILE_1)

        # Only session1's files for the widget should be removed.
        self.mgr.remove_files("session1", "widget")
        self.assertEqual([], self.mgr.get_all_files("session1", "widget"))
        self.assertEqual([f2], self.mgr.get_all_files("session2", "widget"))

    def test_remove_session_files(self):
        # This should not error.
        self.mgr.remove_session_files("non-report")

        # Add files under two widgets of session1 and one widget of session2.
        self.mgr.add_file("session1", "widget1", FILE_1)
        self.mgr.add_file("session1", "widget2", FILE_1)
        f3 = self.mgr.add_file("session2", "widget", FILE_1)

        # Every widget of session1 should be emptied; session2 is untouched.
        self.mgr.remove_session_files("session1")
        self.assertEqual([], self.mgr.get_all_files("session1", "widget1"))
        self.assertEqual([], self.mgr.get_all_files("session1", "widget2"))
        self.assertEqual([f3], self.mgr.get_all_files("session2", "widget"))

    def test_remove_orphaned_files(self):
        """Test the remove_orphaned_files behavior"""
        f1 = self.mgr.add_file("session1", "widget1", FILE_1)
        f2 = self.mgr.add_file("session1", "widget1", FILE_1)
        f3 = self.mgr.add_file("session1", "widget1", FILE_1)
        self.assertEqual([f1, f2, f3], self.mgr.get_all_files("session1", "widget1"))

        # Nothing should be removed here (all files are active).
        self.mgr.remove_orphaned_files(
            "session1",
            "widget1",
            newest_file_id=f3.id,
            active_file_ids=[f1.id, f2.id, f3.id],
        )
        self.assertEqual([f1, f2, f3], self.mgr.get_all_files("session1", "widget1"))

        # Nothing should be removed here (no files are active, but they're all
        # "newer" than newest_file_id).
        self.mgr.remove_orphaned_files(
            "session1", "widget1", newest_file_id=f1.id - 1, active_file_ids=[]
        )
        self.assertEqual([f1, f2, f3], self.mgr.get_all_files("session1", "widget1"))

        # f2 should be removed here (it's not in the active file list)
        self.mgr.remove_orphaned_files(
            "session1", "widget1", newest_file_id=f3.id, active_file_ids=[f1.id, f3.id]
        )
        self.assertEqual([f1, f3], self.mgr.get_all_files("session1", "widget1"))

        # remove_orphaned_files on an untracked session/widget should not error
        self.mgr.remove_orphaned_files(
            "no_session", "no_widget", newest_file_id=0, active_file_ids=[]
        )
def setUp(self):
    """Give each test a fresh manager and an empty notification log."""
    manager = UploadedFileManager()
    self.mgr = manager
    # Filled by self._on_files_updated each time the manager signals.
    self.filemgr_events = []
    manager.on_files_updated.connect(self._on_files_updated)