Пример #1
0
    def test_coalesce_widget_states(self):
        """coalesce_widget_states merges an old and a new WidgetStates.

        Expected results: a trigger that was True in the old state is still
        True, widgets that only appear in the old state are dropped, new
        widgets keep their values, and a widget that stopped being a trigger
        takes its new value.
        """
        previous = WidgetStates()
        _create_widget("old_set_trigger", previous).trigger_value = True
        _create_widget("old_unset_trigger", previous).trigger_value = False
        _create_widget("missing_in_new", previous).int_value = 123
        _create_widget("shape_changing_trigger", previous).trigger_value = True

        current = WidgetStates()
        _create_widget("old_set_trigger", current).trigger_value = False
        _create_widget("new_set_trigger", current).trigger_value = True
        _create_widget("added_in_new", current).int_value = 456
        _create_widget("shape_changing_trigger", current).int_value = 3

        merged = coalesce_widget_states(previous, current)
        manager = WidgetStateManager()
        manager.set_state(merged)

        # Widgets that don't appear in the new state have no value afterwards.
        for absent_id in ("old_unset_trigger", "missing_in_new"):
            self.assertIsNone(manager.get_widget_value(absent_id))

        expectations = (
            ("old_set_trigger", True),
            ("new_set_trigger", True),
            ("added_in_new", 456),
        )
        for widget_id, expected in expectations:
            self.assertEqual(expected, manager.get_widget_value(widget_id))

        # Widgets that were triggers before, but no longer are, will *not*
        # be coalesced
        self.assertEqual(3, manager.get_widget_value("shape_changing_trigger"))
Пример #2
0
    def test_reset_triggers(self):
        """reset_triggers() clears trigger values but leaves other widgets."""
        manager = WidgetStateManager()

        initial = WidgetStates()
        _create_widget("trigger", initial).trigger_value = True
        _create_widget("int", initial).int_value = 123
        manager.set_state(initial)

        self.assertEqual(True, manager.get_widget_value("trigger"))
        self.assertEqual(123, manager.get_widget_value("int"))

        manager.reset_triggers()

        # The trigger value is gone; the int widget is untouched.
        self.assertIsNone(manager.get_widget_value("trigger"))
        self.assertEqual(123, manager.get_widget_value("int"))
Пример #3
0
    def test_values(self):
        """Each supported widget value type round-trips through the manager."""
        states = WidgetStates()
        _create_widget("trigger", states).trigger_value = True
        _create_widget("bool", states).bool_value = True
        _create_widget("float", states).double_value = 0.5
        _create_widget("int", states).int_value = 123
        _create_widget("string", states).string_value = "howdy!"

        manager = WidgetStateManager()
        manager.set_state(states)
        value_of = manager.get_widget_value

        self.assertEqual(True, value_of("trigger"))
        self.assertEqual(True, value_of("bool"))
        # Floating-point value: compared approximately.
        self.assertAlmostEqual(0.5, value_of("float"))
        self.assertEqual(123, value_of("int"))
        self.assertEqual("howdy!", value_of("string"))
Пример #4
0
    def test_handle_save_request(self, _1):
        """Test that handle_save_request serializes files correctly.

        ``_1`` is presumably a mock injected by a patch decorator that is not
        visible here; it is unused by the body.
        """
        # Create a ReportSession with some mocked bits
        rs = ReportSession(self.io_loop, "mock_report.py", "", UploadedFileManager())
        rs._report.report_id = "TestReportID"

        orig_ctx = get_report_ctx()
        ctx = ReportContext(
            "TestSessionID",
            rs._report.enqueue,
            "",
            WidgetStateManager(),
            UploadedFileManager(),
        )
        add_report_ctx(ctx=ctx)

        # Restore the original ReportContext even when an assertion below
        # fails, so this test's context can't leak into later tests.
        # NOTE(review): if orig_ctx is None, whether add_report_ctx(ctx=None)
        # actually detaches our ctx depends on add_report_ctx — verify.
        try:
            rs._scriptrunner = MagicMock()

            storage = MockStorage()
            rs._storage = storage

            # Send two deltas: empty and markdown
            st.empty()
            st.markdown("Text!")

            yield rs.handle_save_request(_create_mock_websocket())

            # Check the order of the received files. Manifest should be last.
            self.assertEqual(3, len(storage.files))
            self.assertEqual("reports/TestReportID/0.pb", storage.get_filename(0))
            self.assertEqual("reports/TestReportID/1.pb", storage.get_filename(1))
            self.assertEqual(
                "reports/TestReportID/manifest.pb", storage.get_filename(2)
            )

            # Check the manifest
            manifest = storage.get_message(2, StaticManifest)
            self.assertEqual("mock_report", manifest.name)
            self.assertEqual(2, manifest.num_messages)
            self.assertEqual(StaticManifest.DONE, manifest.server_status)

            # Check that the deltas we sent match messages in storage
            sent_messages = rs._report._master_queue._queue
            received_messages = [
                storage.get_message(0, ForwardMsg),
                storage.get_message(1, ForwardMsg),
            ]

            self.assertEqual(sent_messages, received_messages)
        finally:
            add_report_ctx(ctx=orig_ctx)
Пример #5
0
    def setUp(self, override_root=True):
        """Prepare a ReportQueue and, optionally, install a test ReportContext.

        When ``override_root`` is true, the current thread's ReportContext is
        remembered in ``self.orig_report_ctx`` and replaced by one whose
        ``enqueue`` feeds ``self.report_queue``.
        """
        self.report_queue = ReportQueue()
        self.override_root = override_root
        # Stays None unless we actually swap in our own context below.
        self.orig_report_ctx = None

        if not self.override_root:
            return

        self.orig_report_ctx = get_report_ctx()
        test_ctx = ReportContext(
            session_id="test session id",
            enqueue=self.report_queue.enqueue,
            query_string="",
            widgets=WidgetStateManager(),
            uploaded_file_mgr=UploadedFileManager(),
        )
        add_report_ctx(threading.current_thread(), test_ctx)
Пример #6
0
    def test_set_page_config_immutable(self):
        """st.set_page_config may be sent at most once per run."""

        def discard(msg):
            # Messages are dropped; only ReportContext's own checks matter.
            return None

        ctx = ReportContext(
            "TestSessionID",
            discard,
            "",
            WidgetStateManager(),
            UploadedFileManager(),
        )

        page_config = ForwardMsg()
        page_config.page_config_changed.title = "foo"

        # The first page_config message is accepted; a second one must raise.
        ctx.enqueue(page_config)
        with self.assertRaises(StreamlitAPIException):
            ctx.enqueue(page_config)
Пример #7
0
    def test_set_page_config_reset(self):
        """st.set_page_config is allowed again once the context is reset
        (which is what happens on a rerun)."""

        def discard(msg):
            return None

        ctx = ReportContext(
            "TestSessionID",
            discard,
            "",
            WidgetStateManager(),
            UploadedFileManager(),
        )

        page_config = ForwardMsg()
        page_config.page_config_changed.title = "foo"

        ctx.enqueue(page_config)
        ctx.reset()

        # After reset(), sending page_config a second time must not raise.
        try:
            ctx.enqueue(page_config)
        except StreamlitAPIException:
            self.fail("set_page_config should have succeeded after reset!")
Пример #8
0
    def test_set_page_config_first(self):
        """st.set_page_config must precede every other st command."""

        def discard(msg):
            return None

        ctx = ReportContext(
            "TestSessionID",
            discard,
            "",
            WidgetStateManager(),
            UploadedFileManager(),
        )

        # A markdown delta is sent first...
        markdown_msg = ForwardMsg()
        markdown_msg.delta.new_element.markdown.body = "foo"
        ctx.enqueue(markdown_msg)

        # ...so a page_config message sent afterwards must be rejected.
        page_config = ForwardMsg()
        page_config.page_config_changed.title = "foo"
        with self.assertRaises(StreamlitAPIException):
            ctx.enqueue(page_config)
Пример #9
0
    def __init__(
        self,
        session_id,
        report,
        enqueue_forward_msg,
        client_state,
        request_queue,
        uploaded_file_mgr=None,
    ):
        """Initialize the ScriptRunner.

        (The ScriptRunner won't start executing until start() is called.)

        Parameters
        ----------
        session_id : str
            The ReportSession's id.

        report : Report
            The ReportSession's report.

        enqueue_forward_msg : Callable[[ForwardMsg], None]
            Function that receives the ForwardMsgs produced while the script
            runs (stored as ``self._enqueue_forward_msg``).

        client_state : streamlit.proto.ClientState_pb2.ClientState
            The current state from the client (widgets and query params).

        request_queue : ScriptRequestQueue
            The queue that the ReportSession is publishing ScriptRequests to.
            ScriptRunner will continue running until the queue is empty,
            and then shut down.

        uploaded_file_mgr : UploadedFileManager
            The File manager to store the data uploaded by the file_uploader widget.

        """
        self._session_id = session_id
        self._report = report
        self._enqueue_forward_msg = enqueue_forward_msg
        self._request_queue = request_queue
        self._uploaded_file_mgr = uploaded_file_mgr

        self._client_state = client_state
        # Seed the widget manager with the widget values the client sent.
        self._widgets = WidgetStateManager()
        self._widgets.set_state(client_state.widget_states)

        self.on_event = Signal(doc="""Emitted when a ScriptRunnerEvent occurs.

            This signal is *not* emitted on the same thread that the
            ScriptRunner was created on.

            Parameters
            ----------
            event : ScriptRunnerEvent

            exception : BaseException | None
                Our compile error. Set only for the
                SCRIPT_STOPPED_WITH_COMPILE_ERROR event.

            client_state : streamlit.proto.ClientState_pb2.ClientState | None
                The ScriptRunner's final ClientState (query string and
                widget states). Set only for the SHUTDOWN event.
            """)

        # Set to true when we process a SHUTDOWN request
        self._shutdown_requested = False

        # Set to true while we're executing. Used by
        # maybe_handle_execution_control_request.
        self._execing = False

        # This is initialized in start()
        self._script_thread = None
Пример #10
0
class ScriptRunner(object):
    def __init__(
        self,
        session_id,
        report,
        enqueue_forward_msg,
        client_state,
        request_queue,
        uploaded_file_mgr=None,
    ):
        """Initialize the ScriptRunner.

        (The ScriptRunner won't start executing until start() is called.)

        Parameters
        ----------
        session_id : str
            The ReportSession's id.

        report : Report
            The ReportSession's report.

        enqueue_forward_msg : Callable[[ForwardMsg], None]
            Function that receives the ForwardMsgs produced while the script
            runs. start() hands it to the script thread as its ``enqueue``
            function.

        client_state : streamlit.proto.ClientState_pb2.ClientState
            The current state from the client (widgets and query params).

        request_queue : ScriptRequestQueue
            The queue that the ReportSession is publishing ScriptRequests to.
            ScriptRunner will continue running until the queue is empty,
            and then shut down.

        uploaded_file_mgr : UploadedFileManager
            The File manager to store the data uploaded by the file_uploader widget.

        """
        self._session_id = session_id
        self._report = report
        self._enqueue_forward_msg = enqueue_forward_msg
        self._request_queue = request_queue
        self._uploaded_file_mgr = uploaded_file_mgr

        self._client_state = client_state
        # Seed the widget manager with the client's widget values. start()
        # hands this same object to the script thread, so later updates made
        # in _run_script_once are visible there in place.
        self._widgets = WidgetStateManager()
        self._widgets.set_state(client_state.widget_states)

        self.on_event = Signal(doc="""Emitted when a ScriptRunnerEvent occurs.

            This signal is *not* emitted on the same thread that the
            ScriptRunner was created on.

            Parameters
            ----------
            event : ScriptRunnerEvent

            exception : BaseException | None
                Our compile error. Set only for the
                SCRIPT_STOPPED_WITH_COMPILE_ERROR event.

            client_state : streamlit.proto.ClientState_pb2.ClientState | None
                The ScriptRunner's final ClientState (query string and
                widget states). Set only for the SHUTDOWN event.
            """)

        # Set to true when we process a SHUTDOWN request
        self._shutdown_requested = False

        # Set to true while we're executing. Used by
        # maybe_handle_execution_control_request.
        self._execing = False

        # This is initialized in start()
        self._script_thread = None

    def __repr__(self) -> str:
        return util.repr_(self)

    def start(self):
        """Start a new thread to process the ScriptEventQueue.

        This must be called only once.

        Raises
        ------
        Exception
            If the ScriptRunner was already started.

        """
        if self._script_thread is not None:
            raise Exception("ScriptRunner was already started")

        self._script_thread = ReportThread(
            session_id=self._session_id,
            enqueue=self._enqueue_forward_msg,
            query_string=self._client_state.query_string,
            widgets=self._widgets,
            uploaded_file_mgr=self._uploaded_file_mgr,
            target=self._process_request_queue,
            name="ScriptRunner.scriptThread",
        )
        self._script_thread.start()

    def _process_request_queue(self):
        """Process ScriptRequests until the queue is empty or a SHUTDOWN
        request arrives, then emit the SHUTDOWN event.

        This is run in a separate thread.

        """
        LOGGER.debug("Beginning script thread")

        while not self._shutdown_requested and self._request_queue.has_request:
            request, data = self._request_queue.dequeue()
            if request == ScriptRequest.STOP:
                LOGGER.debug("Ignoring STOP request while not running")
            elif request == ScriptRequest.SHUTDOWN:
                LOGGER.debug("Shutting down")
                self._shutdown_requested = True
            elif request == ScriptRequest.RERUN:
                self._run_script(data)
            else:
                raise RuntimeError("Unrecognized ScriptRequest: %s" % request)

        # Send a SHUTDOWN event before exiting. This includes the widget values
        # as they existed after our last successful script run, which the
        # ReportSession will pass on to the next ScriptRunner that gets
        # created.
        client_state = ClientState()
        client_state.query_string = self._client_state.query_string
        self._widgets.marshall(client_state)
        self.on_event.send(ScriptRunnerEvent.SHUTDOWN,
                           client_state=client_state)

    def _is_in_script_thread(self):
        """True if the calling function is running in the script thread"""
        return self._script_thread == threading.current_thread()

    def maybe_handle_execution_control_request(self):
        """Act on a pending ScriptRequest by raising, if one is queued.

        Only has an effect on the script thread while exec() is running.
        STOP and SHUTDOWN raise StopException (SHUTDOWN also sets
        _shutdown_requested); RERUN raises RerunException carrying the
        request's data. Returns without doing anything when no request is
        pending.

        Raises
        ------
        StopException, RerunException
            As described above.
        RuntimeError
            For an unrecognized request type.
        """
        if not self._is_in_script_thread():
            # We can only handle execution_control_request if we're on the
            # script execution thread. However, it's possible for deltas to
            # be enqueued (and, therefore, for this function to be called)
            # in separate threads, so we check for that here.
            return

        if not self._execing:
            # If the _execing flag is not set, we're not actually inside
            # an exec() call. This happens when our script exec() completes,
            # we change our state to STOPPED, and a statechange-listener
            # enqueues a new ForwardEvent
            return

        # Pop the next request from our queue.
        request, data = self._request_queue.dequeue()
        if request is None:
            return

        LOGGER.debug("Received ScriptRequest: %s", request)
        if request == ScriptRequest.STOP:
            raise StopException()
        elif request == ScriptRequest.SHUTDOWN:
            self._shutdown_requested = True
            raise StopException()
        elif request == ScriptRequest.RERUN:
            raise RerunException(data)
        else:
            raise RuntimeError("Unrecognized ScriptRequest: %s" % request)

    def _install_tracer(self):
        """Install function that runs before each line of the script."""
        def trace_calls(frame, event, arg):
            # Give a pending STOP/RERUN request a chance to interrupt the
            # script. Returning ourselves keeps tracing inside the new frame.
            self.maybe_handle_execution_control_request()
            return trace_calls

        # Python interpreters are not required to implement sys.settrace.
        if hasattr(sys, "settrace"):
            sys.settrace(trace_calls)

    @contextmanager
    def _set_execing_flag(self):
        """A context for setting the ScriptRunner._execing flag.

        Used by maybe_handle_execution_control_request to ensure that
        we only handle requests while we're inside an exec() call
        """
        if self._execing:
            raise RuntimeError("Nested set_execing_flag call")
        self._execing = True
        try:
            yield
        finally:
            self._execing = False

    def _run_script(self, rerun_data):
        """Run our script, running it again for as long as runs are
        interrupted by a RerunException.

        Parameters
        ----------
        rerun_data: RerunData
            The RerunData to use for the first run.

        """
        # Reruns are handled in a loop rather than by recursion so that a
        # script thread that gets interrupted many times doesn't keep growing
        # the call stack (and eventually hit RecursionError).
        while True:
            rerun_data = self._run_script_once(rerun_data)
            if rerun_data is None:
                break

    def _run_script_once(self, rerun_data):
        """Run our script a single time.

        Parameters
        ----------
        rerun_data: RerunData
            The RerunData to use.

        Returns
        -------
        RerunData | None
            The RerunData carried by a RerunException that interrupted this
            run, or None if no rerun was requested (including when the script
            failed to compile).

        """
        assert self._is_in_script_thread()

        LOGGER.debug("Running script %s", rerun_data)

        # Reset DeltaGenerators, widgets, media files.
        media_file_manager.clear_session_files()

        ctx = get_report_ctx()
        if ctx is None:
            # This should never be possible on the script_runner thread.
            raise RuntimeError(
                "ScriptRunner thread has a null ReportContext. Something has gone very wrong!"
            )

        ctx.reset(query_string=rerun_data.query_string)

        self.on_event.send(ScriptRunnerEvent.SCRIPT_STARTED)

        # Compile the script. Any errors thrown here will be surfaced
        # to the user via a modal dialog in the frontend, and won't result
        # in their previous report disappearing.

        try:
            with source_util.open_python_file(self._report.script_path) as f:
                filebody = f.read()

            if config.get_option("runner.magicEnabled"):
                filebody = magic.add_magic(filebody, self._report.script_path)

            code = compile(
                filebody,
                # Pass in the file path so it can show up in exceptions.
                self._report.script_path,
                # We're compiling entire blocks of Python, so we need "exec"
                # mode (as opposed to "eval" or "single").
                mode="exec",
                # Don't inherit any flags or "future" statements.
                flags=0,
                dont_inherit=1,
                # Use the default optimization options.
                optimize=-1,
            )

        except BaseException as e:
            # We got a compile error. Send an error event and bail immediately.
            LOGGER.debug("Fatal script error: %s", e)
            self.on_event.send(
                ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR,
                exception=e)
            return None

        # If we get here, we've successfully compiled our script. The next step
        # is to run it. Errors thrown during execution will be shown to the
        # user as ExceptionElements.

        # Update the Widget object with the new widget_states.
        # (The ReportContext has a reference to this object, so we just update it in-place)
        if rerun_data.widget_states is not None:
            self._widgets.set_state(rerun_data.widget_states)

        if config.get_option("runner.installTracer"):
            self._install_tracer()

        # This will be set to a RerunData instance if our execution
        # is interrupted by a RerunException.
        rerun_with_data = None

        try:
            # Create fake module. This gives us a fresh global namespace to
            # execute the code in.
            module = _new_module("__main__")

            # Install the fake module as the __main__ module. This allows
            # the pickle module to work inside the user's code, since it now
            # can know the module where the pickled objects stem from.
            # IMPORTANT: This means we can't use "if __name__ == '__main__'" in
            # our code, as it will point to the wrong module!!!
            sys.modules["__main__"] = module

            # Add special variables to the module's globals dict.
            # Note: The following is a requirement for the CodeHasher to
            # work correctly. The CodeHasher is scoped to
            # files contained in the directory of __main__.__file__, which we
            # assume is the main script directory.
            module.__dict__["__file__"] = self._report.script_path

            with modified_sys_path(self._report), self._set_execing_flag():
                exec(code, module.__dict__)

        except RerunException as e:
            rerun_with_data = e.rerun_data

        except StopException:
            pass

        except BaseException as e:
            handle_uncaught_app_exception(e)

        finally:
            self._on_script_finished(ctx)

        # Use _log_if_error() to make sure we never ever ever stop running the
        # script without meaning to.
        _log_if_error(_clean_problem_modules)

        return rerun_with_data

    def _on_script_finished(self, ctx: ReportContext) -> None:
        """Called when our script finishes executing, even if it finished
        early with an exception. We perform post-run cleanup here.
        """
        self._widgets.reset_triggers()
        self._widgets.cull_nonexistent(ctx.widget_ids_this_run.items())
        # Signal that the script has finished. (We use SCRIPT_STOPPED_WITH_SUCCESS
        # even if we were stopped with an exception.)
        self.on_event.send(ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS)
        # Delete expired files now that the script has run and files in use
        # are marked as active.
        media_file_manager.del_expired_files()
Пример #11
0
    def test_rerun_data_coalescing(self):
        """Queued RERUN requests are merged into one, with the expected
        widget values.

        (This mirrors widgets_test.test_coalesce_widget_states, but
        exercises the behavior through the ScriptRequestQueue interface.)
        """
        queue = ScriptRequestQueue()

        def enqueue_rerun(widget_states):
            queue.enqueue(
                ScriptRequest.RERUN, RerunData(widget_states=widget_states)
            )

        def manager_for(rerun_data):
            manager = WidgetStateManager()
            manager.set_state(rerun_data.widget_states)
            return manager

        def assert_queue_empty():
            # We should have no more events
            self.assertEqual((None, None), queue.dequeue(),
                             "Expected empty event queue")

        # Case 1: two RERUNs that both carry widget states.
        first = WidgetStates()
        _create_widget("trigger", first).trigger_value = True
        _create_widget("int", first).int_value = 123
        enqueue_rerun(first)

        second = WidgetStates()
        _create_widget("trigger", second).trigger_value = False
        _create_widget("int", second).int_value = 456
        enqueue_rerun(second)

        event, data = queue.dequeue()
        self.assertEqual(event, ScriptRequest.RERUN)
        widgets = manager_for(data)

        # Coalesced triggers should be True if either the old or
        # new value was True
        self.assertEqual(True, widgets.get_widget_value("trigger"))
        # Other widgets should have their newest value
        self.assertEqual(456, widgets.get_widget_value("int"))
        assert_queue_empty()

        # Case 2: the earlier requests carry no widget state.
        enqueue_rerun(None)
        enqueue_rerun(None)
        latest = WidgetStates()
        _create_widget("int", latest).int_value = 789
        enqueue_rerun(latest)

        event, data = queue.dequeue()
        widgets = manager_for(data)
        self.assertEqual(event, ScriptRequest.RERUN)
        self.assertEqual(789, widgets.get_widget_value("int"))
        assert_queue_empty()

        # Case 3: the newest request carries no widget state.
        earlier = WidgetStates()
        _create_widget("int", earlier).int_value = 101112
        enqueue_rerun(earlier)
        enqueue_rerun(None)

        event, data = queue.dequeue()
        widgets = manager_for(data)
        self.assertEqual(event, ScriptRequest.RERUN)
        self.assertEqual(101112, widgets.get_widget_value("int"))
        assert_queue_empty()