def test_handle_save_request(self, _1):
        """Test that handle_save_request serializes files correctly.

        Sends two deltas through a ReportSession backed by a MockStorage,
        triggers a save, then verifies both the file ordering (delta files
        first, manifest last) and that the stored messages match the ones
        that were enqueued.
        """
        # Create a ReportSession with some mocked bits
        rs = ReportSession(
            self.io_loop, "mock_report.py", "", UploadedFileManager(), None
        )
        rs._report.report_id = "TestReportID"

        # Install a report context for this thread so the st.* calls below
        # enqueue into rs._report; the original context is restored at the end.
        orig_ctx = get_report_ctx()
        ctx = ReportContext(
            "TestSessionID",
            rs._report.enqueue,
            "",
            SessionState(),
            UploadedFileManager(),
        )
        add_report_ctx(ctx=ctx)

        # A real ScriptRunner is not needed for the save path; stub it out.
        rs._scriptrunner = MagicMock()

        storage = MockStorage()
        rs._storage = storage

        # Send two deltas: empty and markdown
        st.empty()
        st.markdown("Text!")

        # handle_save_request is asynchronous; yield drives it to completion
        # under the test's IOLoop.
        yield rs.handle_save_request(_create_mock_websocket())

        # Check the order of the received files. Manifest should be last.
        self.assertEqual(3, len(storage.files))
        self.assertEqual("reports/TestReportID/0.pb", storage.get_filename(0))
        self.assertEqual("reports/TestReportID/1.pb", storage.get_filename(1))
        self.assertEqual("reports/TestReportID/manifest.pb", storage.get_filename(2))

        # Check the manifest
        manifest = storage.get_message(2, StaticManifest)
        self.assertEqual("mock_report", manifest.name)
        self.assertEqual(2, manifest.num_messages)
        self.assertEqual(StaticManifest.DONE, manifest.server_status)

        # Check that the deltas we sent match messages in storage
        sent_messages = rs._report._master_queue._queue
        received_messages = [
            storage.get_message(0, ForwardMsg),
            storage.get_message(1, ForwardMsg),
        ]

        self.assertEqual(sent_messages, received_messages)

        # Restore whatever context was active before this test ran.
        add_report_ctx(ctx=orig_ctx)
    def test_enqueue_new_report_message(self, _1, _2, patched_config):
        """SCRIPT_STARTED should enqueue NewReport and SessionState messages."""

        def fake_get_option(name):
            # Just to avoid starting the watcher for no reason; every other
            # option falls through to the real config.
            return False if name == "server.runOnSave" else config.get_option(name)

        patched_config.get_option.side_effect = fake_get_option
        patched_config.get_options_for_section.side_effect = (
            _mock_get_options_for_section()
        )

        # Build a session with mocked collaborators.
        session = ReportSession(
            self.io_loop, "mock_report.py", "", UploadedFileManager()
        )
        session._report.report_id = "testing _enqueue_new_report"

        # Swap in a test report context; restored at the bottom.
        original_ctx = get_report_ctx()
        add_report_ctx(
            ctx=ReportContext(
                "TestSessionID", session._report.enqueue, "", None, None
            )
        )

        session._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STARTED)

        # NewReport and SessionState messages
        queued = session._report._master_queue._queue
        self.assertEqual(len(queued), 2)

        # Note that we're purposefully not very thoroughly testing new_report
        # fields below to avoid getting to the point where we're just
        # duplicating code in tests.
        new_report_msg = queued[0].new_report
        self.assertEqual(new_report_msg.report_id, session._report.report_id)

        self.assertEqual(new_report_msg.HasField("config"), True)
        self.assertEqual(
            new_report_msg.config.allow_run_on_save,
            config.get_option("server.allowRunOnSave"),
        )

        self.assertEqual(new_report_msg.HasField("custom_theme"), True)
        self.assertEqual(new_report_msg.custom_theme.text_color, "black")

        init_msg = new_report_msg.initialize
        self.assertEqual(init_msg.HasField("user_info"), True)

        add_report_ctx(ctx=original_ctx)
Example #3
0
    def setUp(self, override_root=True):
        """Create a fresh ReportQueue; optionally install a root ReportContext.

        When ``override_root`` is True, the current thread's report context
        is saved in ``self.orig_report_ctx`` and replaced with a test
        context that enqueues into ``self.report_queue``.
        """
        self.report_queue = ReportQueue()
        self.override_root = override_root
        self.orig_report_ctx = None

        if not self.override_root:
            return

        self.orig_report_ctx = get_report_ctx()
        test_ctx = ReportContext(
            session_id="test session id",
            enqueue=self.report_queue.enqueue,
            query_string="",
            widgets=Widgets(),
            uploaded_file_mgr=UploadedFileManager(),
        )
        add_report_ctx(threading.current_thread(), test_ctx)
Example #4
0
 def tearDown(self):
     """Drain the queue and restore the report context saved by setUp."""
     self.clear_queue()
     if self.override_root:
         # Put back whatever context was active before the test (may be None).
         add_report_ctx(threading.current_thread(), self.orig_report_ctx)
Example #5
0
def _wrapped(session_state_ref: 'ReferenceType[_SessionState]',
             cb_ref: Union[str, List[Tuple[Callable[..., Any],
                                           'ReferenceType[_SessionState]']]],
             need_report: bool = False,
             delegate_stop: bool = True,
             args: Optional[List[Any]] = None,
             kwargs: Optional[Dict[Any, Any]] = None):
    """Invoke a registered callback under its session's report context.

    Args:
        session_state_ref: Weak reference to the owning ``_SessionState``.
        cb_ref: Either the string key of a callback registered in
            ``session_state.callbacks``, or a single-element list holding a
            ``(callable, ctx_weakref)`` tuple.
        need_report: If True, temporarily mark the report as running while
            the callback executes (when it was not already running).
        delegate_stop: If True, ``StopException`` is re-raised to the caller.
        args: Positional arguments for the callback (default: none).
        kwargs: Keyword arguments for the callback (default: none).

    Raises:
        StopException: When the session, function, or context is gone or
            mismatched. Propagation of most cases depends on
            ``delegate_stop``; the "deleted function" case always raises.
    """
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}

    session_state = session_state_ref()
    if session_state is None:
        if delegate_stop:
            raise StopException("No session state")
        return

    # Resolve the callable and its context weakref from either form of cb_ref.
    fun, function_ctx_ref = None, None
    if isinstance(cb_ref, str):
        fun, function_ctx_ref = session_state.callbacks.get(
            cb_ref, (None, None))
    elif cb_ref:
        # Non-empty list: the (callable, ctx_weakref) pair is its first item.
        # (A redundant inner len() check was removed -- the branch condition
        # already guarantees the list is non-empty.)
        fun, function_ctx_ref = cb_ref[0]
    if fun is None:
        raise StopException(
            "Deleted function, probably not delegated stop, or not handled")
    function_ctx = function_ctx_ref()

    thread = threading.current_thread()
    # Remember the context currently attached to this thread so it can be
    # restored (or removed) in the outer finally block.
    orig_ctx = getattr(thread, REPORT_CONTEXT_ATTR_NAME, None)
    set_other_ctx = False
    rsession = session_state.get_session()
    rerun = None
    try:
        if function_ctx is None:
            raise StopException("No function context")

        rstate = session_state.get_session_state()
        current_ctx = session_state.get_ctx()
        # Only run the callback if it was registered for the *current*
        # context; otherwise it belongs to a stale session run.
        if current_ctx != function_ctx:
            raise StopException("Other context")

        add_report_ctx(thread=thread, ctx=current_ctx)
        set_other_ctx = True
        need_report_not_running = False
        try:
            if need_report:
                if rstate == ReportSessionState.REPORT_NOT_RUNNING:
                    # Temporarily flag the report as running for the
                    # duration of the callback.
                    need_report_not_running = True
                    _SessionState.set_state(
                        rsession, ReportSessionState.REPORT_IS_RUNNING)
            fun(*args, **kwargs)
        finally:
            # Revert the "running" flag unless a real report thread has
            # since taken over.
            if need_report_not_running and _SessionState.get_report_thread(
                    rsession) is None:
                _SessionState.set_state(rsession,
                                        ReportSessionState.REPORT_NOT_RUNNING)
    except StopException as e:
        if isinstance(e, _RerunAndStopException):
            rerun = e.rerun_data

        # A stop invalidates the callback registration either way.
        if isinstance(cb_ref, str):
            del session_state.callbacks[cb_ref]
        else:
            del cb_ref[0]
        if delegate_stop:
            raise
    except RerunException as e:
        rerun = e.rerun_data
    except BaseException:
        import traceback
        traceback.print_exc()
        raise
    finally:
        if rerun:
            rsession.request_rerun(rerun)

        if set_other_ctx:
            if orig_ctx is None:
                delattr(thread, REPORT_CONTEXT_ATTR_NAME)
            else:
                add_report_ctx(thread=thread, ctx=orig_ctx)
Example #6
0
    # Dump the resolved training options for inspection.
    print(json.dumps(options.as_dict(), indent=4))

    # Create a queue that will process st.writes generated
    # from other threads
    stats_queue = queue.Queue()
    training_complete = False
    training_start_time = time.time()

    chart = st.altair_chart(create_chart())

    def stats_worker():
        # Consume StatsEntry items and append cumulative-reward stats to the
        # chart as new rows.
        # NOTE(review): training_complete is only re-checked between queue
        # items, and stats_queue.get() blocks forever -- if no entry arrives
        # after training_complete is set, this thread (and the join() below)
        # can hang. Confirm the trainer always enqueues a final entry.
        while not training_complete:
            entry: StatsEntry = stats_queue.get()
            if "Environment/Cumulative Reward" in entry.values:
                stats_summary = entry.values["Environment/Cumulative Reward"]
                chart.add_rows(
                    pd.DataFrame({
                        "step": [entry.step],
                        "mean": [stats_summary.mean]
                    }))
            stats_queue.task_done()

    # Attach the Streamlit report context to the worker thread so its st
    # calls are routed to this report.
    command_thread = threading.Thread(target=stats_worker)
    add_report_ctx(command_thread)
    command_thread.start()

    streamlit_learn.run_training(run_seed, options, stats_queue)

    # Signal the worker loop to exit and wait for it to finish.
    training_complete = True
    command_thread.join()
Example #7
0
    def run_experiment():
        '''
        The interface to setup the estimator, configuration, data loading, etc. is
        nearly identical to a Jupyter Notebook interface for Sapsan. In an ideal case,
        this is the only function you need to edit to set up your own GUI demo.

        Reads all settings from the enclosing ``widget_values`` dict, trains
        an Estimator, evaluates it on a test set, and renders PDF/CDF and
        slice plots to the Streamlit page.
        '''

        # Pick the experiment-tracking backend from the widget selection.
        # NOTE(review): if backend_selection is neither 'Fake' nor 'MLflow',
        # tracking_backend is never assigned and the Train(...) call below
        # raises NameError -- confirm the selector only offers these two.
        if widget_values['backend_selection'] == 'Fake':
            tracking_backend = FakeBackend(widget_values['experiment name'])

        elif widget_values['backend_selection'] == 'MLflow':
            tracking_backend = MLflowBackend(widget_values['experiment name'],
                                             widget_values['mlflow_host'],
                                             widget_values['mlflow_port'])

        #Load the data
        data_loader = load_data(widget_values['checkpoints'])
        x, y = data_loader.load_numpy()
        y = flatten(y)
        loaders = data_loader.convert_to_torch([x, y])

        st.write("Dataset loaded...")

        # Hyperparameters arrive as widget strings, hence int()/float().
        estim = Estimator(config=EstimatorConfig(
            n_epochs=int(widget_values['n_epochs']),
            patience=int(widget_values['patience']),
            min_delta=float(widget_values['min_delta'])),
                          loaders=loaders)

        #Set the experiment
        training_experiment = Train(backend=tracking_backend,
                                    model=estim,
                                    data_parameters=data_loader,
                                    show_log=False)

        #Plot progress
        progress_slot = st.empty()
        epoch_slot = st.empty()

        # Stream training-log updates into the placeholders from a side
        # thread; add_report_ctx routes its st calls to this report.
        thread = Thread(target=show_log, args=(progress_slot, epoch_slot))
        add_report_ctx(thread)
        thread.start()

        start = time.time()
        #Train the model
        trained_estimator = training_experiment.run()

        st.write('Trained in %.2f sec' % ((time.time() - start)))
        st.success('Done! Plotting...')

        #--- Test the model ---
        #Load the test data
        data_loader = load_data(widget_values['checkpoint_test'])
        x, y = data_loader.load_numpy()
        loaders = data_loader.convert_to_torch([x, y])

        #Set the test experiment
        trained_estimator.loaders = loaders
        evaluation_experiment = Evaluate(backend=tracking_backend,
                                         model=trained_estimator,
                                         data_parameters=data_loader)

        #Test the model
        cubes = evaluation_experiment.run()

        #Plot PDF, CDF, and slices
        #Similar setup to replot from sapsan.Evaluate()
        mpl.rcParams.update(plot_params())

        fig = plt.figure(figsize=(12, 6), dpi=60)
        (ax1, ax2) = fig.subplots(1, 2)

        pdf_plot([cubes['pred_cube'], cubes['target_cube']],
                 names=['prediction', 'target'],
                 ax=ax1)
        cdf_plot([cubes['pred_cube'], cubes['target_cube']],
                 names=['prediction', 'target'],
                 ax=ax2)
        plot_static()

        slices_cubes = evaluation_experiment.split_batch(cubes['pred_cube'])
        slice_plot([slices_cubes['pred_slice'], slices_cubes['target_slice']],
                   names=['prediction', 'target'],
                   cmap=evaluation_experiment.cmap)
        st.pyplot(plt)
Example #8
0
def stylize_image(args, ctx):
    """Worker: attach the report context to this thread, then stylize.

    Args:
        args: Indexable of at least four items. ``args[0]`` is an opaque key
            returned unchanged; ``args[1]``, ``args[2]``, ``args[3]`` are
            passed to ``las.stylize`` (presumably content/style inputs --
            confirm against the caller).
        ctx: The Streamlit report context to install on this worker thread.

    Returns:
        Tuple ``(args[0], stylized_image)``.
    """
    # current_thread() replaces the deprecated currentThread() alias.
    add_report_ctx(threading.current_thread(), ctx)
    img = las.stylize(args[1], args[2], args[3])
    return args[0], img
Example #9
0
            project = st.selectbox("Choose a project", list(projects.keys()))

        # When new category, this handle the errors
        try:
            proj_id = projects[project]
        except:
            pass

        # Start pomodoro
        if st.button("Start Pomodoro"):

            # TO DO: TRY MULTIPROCESS

            thread = threading.Thread(target=run_pomodoro,
                                      args=[pomodoro_queries])
            add_report_ctx(thread)
            thread.start()

        # Retrieve the hour where the pomodoro started
        hour = st.experimental_get_query_params()

        try:
            hour = hour["starting_hour"]
            write_time = st.empty()
            write_time.info(f"Pomodoro started at: {hour[0]}")
        except:
            pass
        sel_options = st.empty()
        send_btn = st.empty()

        selections = {"Good": 1, "Bad": 2}