def test_handle_save_request(self, _1):
    """Test that handle_save_request serializes files correctly.

    Enqueues two deltas (``st.empty()`` and ``st.markdown``) through a
    ReportSession whose storage is a MockStorage. It then checks that three
    files are written: ``0.pb`` and ``1.pb`` for the two messages, followed
    by ``manifest.pb``. The manifest must describe two messages with status
    DONE, and the stored messages must equal those in the report's master
    queue.
    """
    # Create a ReportSession with some mocked bits
    rs = ReportSession(
        self.io_loop, "mock_report.py", "", UploadedFileManager(), None
    )
    rs._report.report_id = "TestReportID"

    orig_ctx = get_report_ctx()
    ctx = ReportContext(
        "TestSessionID",
        rs._report.enqueue,
        "",
        SessionState(),
        UploadedFileManager(),
    )
    add_report_ctx(ctx=ctx)

    # Everything after installing ctx runs under try/finally. A failing
    # assertion then cannot leave this test's context attached to the
    # thread for the tests that run after it.
    try:
        rs._scriptrunner = MagicMock()

        storage = MockStorage()
        rs._storage = storage

        # Send two deltas: empty and markdown
        st.empty()
        st.markdown("Text!")

        yield rs.handle_save_request(_create_mock_websocket())

        # Check the order of the received files. Manifest should be last.
        self.assertEqual(3, len(storage.files))
        self.assertEqual("reports/TestReportID/0.pb", storage.get_filename(0))
        self.assertEqual("reports/TestReportID/1.pb", storage.get_filename(1))
        self.assertEqual(
            "reports/TestReportID/manifest.pb", storage.get_filename(2)
        )

        # Check the manifest
        manifest = storage.get_message(2, StaticManifest)
        self.assertEqual("mock_report", manifest.name)
        self.assertEqual(2, manifest.num_messages)
        self.assertEqual(StaticManifest.DONE, manifest.server_status)

        # Check that the deltas we sent match messages in storage
        sent_messages = rs._report._master_queue._queue
        received_messages = [
            storage.get_message(0, ForwardMsg),
            storage.get_message(1, ForwardMsg),
        ]

        self.assertEqual(sent_messages, received_messages)
    finally:
        add_report_ctx(ctx=orig_ctx)
def test_enqueue_new_report_message(self, _1, _2, patched_config):
    """Test that a SCRIPT_STARTED event enqueues a NewReport message.

    Firing ScriptRunnerEvent.SCRIPT_STARTED should enqueue exactly two
    messages (NewReport and SessionState). The first must carry the report
    id, a config section, the custom theme and user info.
    """

    def get_option(name):
        if name == "server.runOnSave":
            # Just to avoid starting the watcher for no reason.
            return False

        return config.get_option(name)

    patched_config.get_option.side_effect = get_option
    patched_config.get_options_for_section.side_effect = (
        _mock_get_options_for_section()
    )

    # Create a ReportSession with some mocked bits
    rs = ReportSession(self.io_loop, "mock_report.py", "", UploadedFileManager())
    rs._report.report_id = "testing _enqueue_new_report"

    orig_ctx = get_report_ctx()
    ctx = ReportContext("TestSessionID", rs._report.enqueue, "", None, None)
    add_report_ctx(ctx=ctx)

    # Restore the original context even if an assertion fails, so this
    # test's context does not leak into later tests on the same thread.
    try:
        rs._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STARTED)

        sent_messages = rs._report._master_queue._queue
        self.assertEqual(len(sent_messages), 2)  # NewReport and SessionState messages

        # Note that we're purposefully not very thoroughly testing new_report
        # fields below to avoid getting to the point where we're just
        # duplicating code in tests.
        new_report_msg = sent_messages[0].new_report
        self.assertEqual(new_report_msg.report_id, rs._report.report_id)

        self.assertTrue(new_report_msg.HasField("config"))
        self.assertEqual(
            new_report_msg.config.allow_run_on_save,
            config.get_option("server.allowRunOnSave"),
        )

        self.assertTrue(new_report_msg.HasField("custom_theme"))
        self.assertEqual(new_report_msg.custom_theme.text_color, "black")

        init_msg = new_report_msg.initialize
        self.assertTrue(init_msg.HasField("user_info"))
    finally:
        add_report_ctx(ctx=orig_ctx)
def setUp(self, override_root=True):
    """Create a fresh ReportQueue and optionally install a test ReportContext.

    Args:
        override_root: When True (the default), the current thread's report
            context is saved to ``self.orig_report_ctx``. A ReportContext that
            enqueues into ``self.report_queue`` is then attached to the
            thread. When False, ``self.orig_report_ctx`` stays None and the
            thread's context is left alone.
    """
    self.report_queue = ReportQueue()
    self.override_root = override_root
    self.orig_report_ctx = None

    if not self.override_root:
        return

    self.orig_report_ctx = get_report_ctx()
    test_ctx = ReportContext(
        session_id="test session id",
        enqueue=self.report_queue.enqueue,
        query_string="",
        widgets=Widgets(),
        uploaded_file_mgr=UploadedFileManager(),
    )
    add_report_ctx(threading.current_thread(), test_ctx)
def tearDown(self):
    """Clear the queue and, if setUp overrode it, restore the thread's context.

    The context restore runs in a ``finally``. If ``clear_queue()`` raises,
    the test's ReportContext is still detached instead of leaking into later
    tests on this thread. The exception from ``clear_queue()`` still
    propagates.
    """
    try:
        self.clear_queue()
    finally:
        if self.override_root:
            add_report_ctx(threading.current_thread(), self.orig_report_ctx)
def _wrapped(session_state_ref: 'ReferenceType[_SessionState]',
             cb_ref: Union[str, List[Tuple[Callable[..., Any], 'ReferenceType[_SessionState]']]],
             need_report: bool = False,
             delegate_stop: bool = True,
             args: Optional[List[Any]] = None,
             kwargs: Optional[Dict[Any, Any]] = None) -> None:
    """Run a registered callback under the context it was registered with.

    Steps:
    1. Resolve the callback from ``cb_ref``.
    2. Check that the owning session state is still alive.
    3. Check that the callback's recorded context equals the session's
       ``get_ctx()``.
    4. Install that context on the current thread and call the callback.
    5. Restore the thread's previous context afterwards.

    Args:
        session_state_ref: Weak reference to the owning ``_SessionState``. It
            is called, and a ``None`` result is treated as "session gone".
        cb_ref: Either a key into ``session_state.callbacks``, or a list whose
            first element is a ``(function, context_ref)`` pair.
        need_report: If True and the session state is ``REPORT_NOT_RUNNING``,
            mark it ``REPORT_IS_RUNNING`` while the callback runs.
        delegate_stop: If True, ``StopException`` is re-raised to the caller.
            Otherwise it is swallowed (after the callback is unregistered) or,
            for a dead session, the function just returns.
        args: Positional arguments forwarded to the callback.
        kwargs: Keyword arguments forwarded to the callback.

    Raises:
        StopException: Always, if no callback function can be resolved.
            Also when the session state is gone, or when a StopException is
            raised during the call, provided ``delegate_stop`` is True.
        BaseException: Any other exception from the callback is re-raised
            after its traceback is printed. ``RerunException`` is the one
            exception that is swallowed; its ``rerun_data`` is handed to
            ``request_rerun``.
    """
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    # Dereference the weak reference; None means the session state is gone.
    session_state = session_state_ref()
    if session_state is None:
        if delegate_stop:
            raise StopException("No session state")
        return
    fun, function_ctx_ref = None, None
    if isinstance(cb_ref, str):
        # Named callback: look it up in the session's callback registry.
        fun, function_ctx_ref = session_state.callbacks.get(
            cb_ref, (None, None))
    elif len(cb_ref):
        # List form: the first entry is the (function, context_ref) pair.
        # (The inner `if len(cb_ref)` is always true on this branch.)
        fun, function_ctx_ref = cb_ref[0] if len(cb_ref) else (None, None)
    if fun is None:
        # Raised before the try below, so the callback is NOT unregistered
        # here and delegate_stop does not apply.
        raise StopException(
            "Deleted function, probably not delegated stop, or not handled")
    function_ctx = function_ctx_ref()
    thread = threading.current_thread()
    # Remember whatever context the thread had so it can be restored in `finally`.
    orig_ctx = getattr(thread, REPORT_CONTEXT_ATTR_NAME, None)
    # Only restore/remove the thread context if we actually replaced it.
    set_other_ctx = False
    rsession = session_state.get_session()
    # Rerun data collected from a rerun-type exception; acted on in `finally`.
    rerun = None
    try:
        if function_ctx is None:
            raise StopException("No function context")
        rstate = session_state.get_session_state()
        current_ctx = session_state.get_ctx()
        # The callback only runs if it was registered under the session's
        # current context; otherwise it is treated as stale and stopped.
        if current_ctx != function_ctx:
            raise StopException("Other context")
        add_report_ctx(thread=thread, ctx=current_ctx)
        set_other_ctx = True
        # True only if this call flipped the state to REPORT_IS_RUNNING and
        # therefore owns flipping it back.
        need_report_not_running = False
        try:
            if need_report:
                if rstate == ReportSessionState.REPORT_NOT_RUNNING:
                    need_report_not_running = True
                    _SessionState.set_state(
                        rsession, ReportSessionState.REPORT_IS_RUNNING)
            fun(*args, **kwargs)
        finally:
            # Revert to REPORT_NOT_RUNNING only if we set it above and
            # get_report_thread() reports no report thread for the session.
            if need_report_not_running and _SessionState.get_report_thread(
                    rsession) is None:
                _SessionState.set_state(rsession,
                                        ReportSessionState.REPORT_NOT_RUNNING)
    except StopException as e:
        # A stop-and-rerun variant also carries rerun data.
        if isinstance(e, _RerunAndStopException):
            rerun = e.rerun_data
        # Unregister the callback so it is not invoked again.
        if isinstance(cb_ref, str):
            del session_state.callbacks[cb_ref]
        else:
            del cb_ref[0]
        if delegate_stop:
            raise
    except RerunException as e:
        # Swallowed: the rerun is requested in `finally` instead.
        rerun = e.rerun_data
    except BaseException:
        import traceback
        traceback.print_exc()
        raise
    finally:
        if rerun:
            rsession.request_rerun(rerun)
        if set_other_ctx:
            # Put the thread back exactly as it was: drop the attribute if it
            # had none, otherwise reinstall the previous context.
            if orig_ctx is None:
                delattr(thread, REPORT_CONTEXT_ATTR_NAME)
            else:
                add_report_ctx(thread=thread, ctx=orig_ctx)
print(json.dumps(options.as_dict(), indent=4))

# Queue through which the training code hands stats entries to the worker
# thread below; the worker turns them into chart rows.
stats_queue = queue.Queue()
training_complete = False
training_start_time = time.time()
chart = st.altair_chart(create_chart())

# Enqueued once training ends. The worker spends its time blocked in
# stats_queue.get(), so a plain "training_complete" flag would never be
# re-checked once the producer goes quiet, and join() below would hang.
# The sentinel wakes the worker up and tells it to exit.
_stats_done = object()


def stats_worker():
    """Consume stats entries and append cumulative-reward points to the chart.

    Runs until the ``_stats_done`` sentinel is taken from the queue. Entries
    queued before the sentinel are still processed, in order.
    """
    while True:
        entry = stats_queue.get()
        try:
            if entry is _stats_done:
                return
            if "Environment/Cumulative Reward" in entry.values:
                stats_summary = entry.values["Environment/Cumulative Reward"]
                chart.add_rows(
                    pd.DataFrame({
                        "step": [entry.step],
                        "mean": [stats_summary.mean],
                    }))
        finally:
            stats_queue.task_done()


command_thread = threading.Thread(target=stats_worker)
# Give the worker thread the report context; presumably required for the
# chart.add_rows() calls it makes to reach this Streamlit session.
add_report_ctx(command_thread)
command_thread.start()
try:
    streamlit_learn.run_training(run_seed, options, stats_queue)
finally:
    # Stop the worker even if training raised, so the thread does not
    # outlive the run blocked on an empty queue.
    training_complete = True
    stats_queue.put(_stats_done)
    command_thread.join()
def run_experiment():
    '''
    The interface to setup the estimator, configuration, data loading, etc. is
    nearly identical to a Jupyter Notebook interface for Sapsan.

    In an ideal case, this is the only function you need to edit to set up
    your own GUI demo.

    Settings are read from ``widget_values``. The function:
    1. Trains an Estimator on the data from ``widget_values['checkpoints']``.
    2. Evaluates it on ``widget_values['checkpoint_test']``.
    3. Renders PDF/CDF plots and prediction/target slices with Streamlit.

    Raises:
        ValueError: if ``widget_values['backend_selection']`` is neither
            'Fake' nor 'MLflow'.
    '''
    backend_selection = widget_values['backend_selection']
    if backend_selection == 'Fake':
        tracking_backend = FakeBackend(widget_values['experiment name'])
    elif backend_selection == 'MLflow':
        tracking_backend = MLflowBackend(widget_values['experiment name'],
                                         widget_values['mlflow_host'],
                                         widget_values['mlflow_port'])
    else:
        # Fail fast with a clear message; otherwise tracking_backend would be
        # unbound below and raise UnboundLocalError.
        raise ValueError('Unknown tracking backend: %r' % (backend_selection,))

    # --- Train the model ---
    # Load the training data
    data_loader = load_data(widget_values['checkpoints'])
    x, y = data_loader.load_numpy()
    y = flatten(y)
    loaders = data_loader.convert_to_torch([x, y])

    st.write("Dataset loaded...")

    estim = Estimator(config=EstimatorConfig(
                          n_epochs=int(widget_values['n_epochs']),
                          patience=int(widget_values['patience']),
                          min_delta=float(widget_values['min_delta'])),
                      loaders=loaders)

    # Set the experiment
    training_experiment = Train(backend=tracking_backend,
                                model=estim,
                                data_parameters=data_loader,
                                show_log=False)

    # Plot progress: show_log runs on its own thread and writes into these
    # two placeholders. The thread gets the report context so its Streamlit
    # calls presumably reach this session.
    progress_slot = st.empty()
    epoch_slot = st.empty()
    log_thread = Thread(target=show_log, args=(progress_slot, epoch_slot))
    add_report_ctx(log_thread)
    log_thread.start()

    start = time.time()
    trained_estimator = training_experiment.run()
    st.write('Trained in %.2f sec' % (time.time() - start))
    st.success('Done! \nPlotting...')

    # --- Test the model ---
    # Load the test data
    data_loader = load_data(widget_values['checkpoint_test'])
    x, y = data_loader.load_numpy()
    loaders = data_loader.convert_to_torch([x, y])

    # Set the test experiment
    trained_estimator.loaders = loaders
    evaluation_experiment = Evaluate(backend=tracking_backend,
                                     model=trained_estimator,
                                     data_parameters=data_loader)

    # Test the model
    cubes = evaluation_experiment.run()

    # Plot PDF, CDF, and slices.
    # Similar setup to replot from sapsan.Evaluate()
    mpl.rcParams.update(plot_params())
    fig = plt.figure(figsize=(12, 6), dpi=60)
    (ax1, ax2) = fig.subplots(1, 2)

    cubes_to_compare = [cubes['pred_cube'], cubes['target_cube']]
    pdf_plot(cubes_to_compare, names=['prediction', 'target'], ax=ax1)
    cdf_plot(cubes_to_compare, names=['prediction', 'target'], ax=ax2)
    plot_static()

    slices_cubes = evaluation_experiment.split_batch(cubes['pred_cube'])
    slice_plot([slices_cubes['pred_slice'], slices_cubes['target_slice']],
               names=['prediction', 'target'],
               cmap=evaluation_experiment.cmap)
    st.pyplot(plt)
def stylize_image(args, ctx):
    """Stylize one image after attaching ``ctx`` to the current thread.

    Args:
        args: Sequence indexed at 0..3. ``args[0]`` is returned unchanged
            alongside the result (presumably an identifier the caller uses to
            match results; confirm at the call site). ``args[1]``,
            ``args[2]`` and ``args[3]`` are passed positionally to
            ``las.stylize``.
        ctx: Report context to attach to the current thread.

    Returns:
        Tuple ``(args[0], stylized_image)``.
    """
    # threading.currentThread() is a deprecated alias of current_thread()
    # (it emits a DeprecationWarning since Python 3.10).
    add_report_ctx(threading.current_thread(), ctx)
    img = las.stylize(args[1], args[2], args[3])
    return args[0], img
project = st.selectbox("Choose a project", list(projects.keys()))

# Per the original note, the lookup can fail for a newly created category.
# In that case proj_id is simply left unassigned here. Only KeyError is
# expected, so anything else is no longer silently swallowed.
try:
    proj_id = projects[project]
except KeyError:
    pass

# Start pomodoro
if st.button("Start Pomodoro"):
    # TODO: try multiprocessing instead of a thread.
    thread = threading.Thread(target=run_pomodoro, args=[pomodoro_queries])
    # Attach the report context to the new thread (presumably so that
    # run_pomodoro can make Streamlit calls).
    add_report_ctx(thread)
    thread.start()

# Retrieve the hour where the pomodoro started
hour = st.experimental_get_query_params()
try:
    hour = hour["starting_hour"]
    write_time = st.empty()
    write_time.info(f"Pomodoro started at: {hour[0]}")
except (KeyError, IndexError):
    # Missing or empty "starting_hour" query parameter: nothing to show.
    pass

sel_options = st.empty()
send_btn = st.empty()

selections = {"Good": 1, "Bad": 2}