Ejemplo n.º 1
0
    def test_safe_widget_state(self):
        """`_safe_widget_state` is expected to leave out unreadable keys.

        The mocked widget state lists both "foo" and "baz" in `keys()`, but
        item access is delegated to a dict that only holds "foo", so reading
        "baz" raises KeyError. The expected result contains "foo" only.
        """
        backing_values = {"foo": "bar"}

        widget_state_mock = MagicMock()
        widget_state_mock.keys = lambda: {"foo", "baz"}
        widget_state_mock.__getitem__.side_effect = backing_values.__getitem__

        self.session_state = SessionState({}, {}, widget_state_mock)

        assert self.session_state._safe_widget_state() == backing_values
Ejemplo n.º 2
0
 def setUp(self):
     """Create the SessionState fixture shared by the tests.

     The three positional arguments are, in order: the old state, the new
     session state, and the new widget state (one plain key and one key
     built from GENERATED_WIDGET_KEY_PREFIX).
     """
     widget_state = WStates({
         "baz": Value("qux2"),
         f"{GENERATED_WIDGET_KEY_PREFIX}-foo-None": Value("bar"),
     })
     self.session_state = SessionState(
         {"foo": "bar", "baz": "qux", "corge": "grault"},
         {"foo": "bar2"},
         widget_state,
     )
Ejemplo n.º 3
0
    def test_set_widget_attrs_nonexistent(self):
        """Metadata set on a fresh SessionState is stored as WidgetMetadata.

        No widget value is ever registered for "fake_widget_id"; only its
        metadata is set, and it must be retrievable from the new widget
        state's `widget_metadata` mapping.
        """
        session_state = SessionState()
        session_state.set_metadata(create_metadata("fake_widget_id", ""))

        stored_metadata = session_state._new_widget_state.widget_metadata[
            "fake_widget_id"]
        # assertIsInstance reports the offending object and the expected
        # type on failure, which assertTrue(isinstance(...)) does not.
        self.assertIsInstance(stored_metadata, WidgetMetadata)
Ejemplo n.º 4
0
    def test_filtered_state_resilient_to_missing_metadata(self):
        """`filtered_state` must cope with a widget that has no metadata.

        The widget state is built with a single generated-key entry and no
        metadata is supplied for it; only the two old-state entries are
        expected in the filtered result.
        """
        orphan_widget_key = f"{GENERATED_WIDGET_KEY_PREFIX}-baz"
        widget_state = WStates({orphan_widget_key: Serialized(None)})
        self.session_state = SessionState(
            {"foo": "bar", "corge": "grault"},
            {},
            widget_state,
        )

        expected = {"foo": "bar", "corge": "grault"}
        assert self.session_state.filtered_state == expected
Ejemplo n.º 5
0
    def __init__(self, ioloop, script_path, command_line,
                 uploaded_file_manager):
        """Initialize the ReportSession.

        Parameters
        ----------
        ioloop : tornado.ioloop.IOLoop
            The Tornado IOLoop that we're running within.

        script_path : str
            Path of the Python file from which this report is generated.

        command_line : str
            Command line as input by the user.

        uploaded_file_manager : UploadedFileManager
            The server's UploadedFileManager.

        """
        # Each ReportSession has a unique string ID.
        self.id = str(uuid.uuid4())

        self._ioloop = ioloop
        self._report = Report(script_path, command_line)
        self._uploaded_file_mgr = uploaded_file_manager

        # A new session starts out with no script running.
        self._state = ReportSessionState.REPORT_NOT_RUNNING

        # Need to remember the client state here because when a script reruns
        # due to the source code changing we need to pass in the previous client state.
        self._client_state = ClientState()

        # The same handler, _on_source_file_changed, is registered both with
        # the source-file watcher and as a config-parsed listener.
        self._local_sources_watcher = LocalSourcesWatcher(
            self._report, self._on_source_file_changed)
        # NOTE(review): the return value is presumably a callable that
        # disconnects the listener -- confirm against config.on_config_parsed.
        self._stop_config_listener = config.on_config_parsed(
            self._on_source_file_changed, force_connect=True)
        self._storage = None
        self._maybe_reuse_previous_run = False
        # Snapshot of the "server.runOnSave" option at construction time.
        self._run_on_save = config.get_option("server.runOnSave")

        # The ScriptRequestQueue is the means by which we communicate
        # with the active ScriptRunner.
        self._script_request_queue = ScriptRequestQueue()

        # No ScriptRunner exists until one is created elsewhere.
        self._scriptrunner = None

        # This needs to be lazily imported to avoid a dependency cycle.
        from streamlit.state.session_state import SessionState

        self._session_state = SessionState()

        LOGGER.debug("ReportSession initialized (id=%s)", self.id)
Ejemplo n.º 6
0
def test_map_set_del_3837_regression():
    """Regression case for `test_map_set_del`.

    The setup is too involved to express conveniently with the hypothesis
    `example` decorator, so it is spelled out by hand: a plain entry "0",
    an unkeyed widget, a compaction, and then a keyed widget whose user key
    is also "0". Setting and deleting "0" afterwards must shrink the state
    by exactly one entry.
    """
    unkeyed_id = (
        "$$GENERATED_WIDGET_KEY-e3e70682-c209-4cac-629f-6fbed82c07cd-None")
    keyed_id = "$$GENERATED_WIDGET_KEY-f728b4fa-4248-5e3a-0a5d-2f346baa9455-0"
    entry_key = "0"

    unkeyed_meta = stst.mock_metadata(unkeyed_id, 0)
    keyed_meta = stst.mock_metadata(keyed_id, 0)

    state = SessionState()
    state[entry_key] = 0
    state.set_unkeyed_widget(unkeyed_meta, unkeyed_id)
    state.compact_state()
    state.set_keyed_widget(keyed_meta, keyed_id, entry_key)

    state[entry_key] = 0
    size_before_delete = len(state)
    del state[entry_key]

    assert entry_key not in state
    assert len(state) == size_before_delete - 1
Ejemplo n.º 7
0
    def __init__(self, script_name):
        """Set up a ScriptRunner for `script_name` in the test_data directory.

        Enqueued messages are collected in `self.report_queue`, and every
        event emitted through `on_event` is appended to `self.events`.
        """
        # Destination for the deltas enqueued by DeltaGenerator.
        self.report_queue = ReportQueue()
        self.script_request_queue = ScriptRequestQueue()

        def _enqueue_then_poll(msg):
            # Store the message, then check for a pending control request.
            self.report_queue.enqueue(msg)
            self.maybe_handle_execution_control_request()

        script_path = os.path.join(
            os.path.dirname(__file__), "test_data", script_name)
        report = Report(script_path, "test command line")

        super(TestScriptRunner, self).__init__(
            session_id="test session id",
            report=report,
            enqueue_forward_msg=_enqueue_then_poll,
            client_state=ClientState(),
            session_state=SessionState(),
            request_queue=self.script_request_queue,
        )

        # Uncaught exceptions thrown by our run thread accumulate here.
        self.script_thread_exceptions = []

        # Every ScriptRunnerEvent we emit accumulates here.
        self.events = []

        def _remember_event(event, **kwargs):
            self.events.append(event)

        self.on_event.connect(_remember_event, weak=False)
Ejemplo n.º 8
0
    def test_call_callbacks(self):
        """Exercise callback dispatch on rerun for six kinds of widget:

        1. No callback is registered.
        2. A callback is registered, but the old and new values are equal,
           so it is not called.
        3. The callback has neither args nor kwargs.
        4. The callback has args only.
        5. The callback has kwargs only.
        6. The callback has both args and kwargs.
        """
        def fill(proto, field_values):
            # Add one widget per (widget id, proto field, value) row.
            for widget_id, field_name, value in field_values:
                setattr(_create_widget(widget_id, proto), field_name, value)

        prev_states = WidgetStates()
        fill(prev_states, [
            ("trigger", "trigger_value", True),
            ("bool", "bool_value", True),
            ("bool2", "bool_value", True),
            ("float", "double_value", 0.5),
            ("int", "int_value", 123),
            ("string", "string_value", "howdy!"),
        ])

        session_state = SessionState()
        session_state.set_widgets_from_proto(prev_states)

        mock_callback = MagicMock()
        passthrough = lambda x, s: x

        # (widget id, value type, callback, callback args, callback kwargs)
        registrations = (
            ("trigger", "trigger_value", None, None, None),
            ("bool", "bool_value", mock_callback, None, None),
            ("bool2", "bool_value", mock_callback, None, None),
            ("float", "double_value", mock_callback, (1,), None),
            ("int", "int_value", mock_callback, None, {"x": 2}),
            ("string", "string_value", mock_callback, (1,), {"x": 2}),
        )
        for widget_id, value_type, callback, args, kwargs in registrations:
            metadata = WidgetMetadata(
                widget_id,
                passthrough,
                lambda x: x,
                value_type=value_type,
                callback=callback,
                callback_args=args,
                callback_kwargs=kwargs,
            )
            session_state._set_widget_metadata(metadata)

        # Every widget except "trigger" and "bool" gets a different value.
        new_states = WidgetStates()
        fill(new_states, [
            ("trigger", "trigger_value", True),
            ("bool", "bool_value", True),
            ("bool2", "bool_value", False),
            ("float", "double_value", 1.5),
            ("int", "int_value", 321),
            ("string", "string_value", "!ydwoh"),
        ])

        session_state.on_script_will_rerun(new_states)

        expected_calls = [call(), call(1), call(x=2), call(1, x=2)]
        mock_callback.assert_has_calls(expected_calls)
Ejemplo n.º 9
0
    def test_get_keyed_widget_values(self):
        """Only the "trigger" widget is expected in `values()`.

        Two trigger widgets are registered. Only "trigger" has its metadata
        created with a third argument of True (presumably marking it as
        user-keyed -- confirm against create_metadata).
        """
        widget_states = WidgetStates()
        for widget_id in ("trigger", "trigger2"):
            _create_widget(widget_id, widget_states).trigger_value = True

        session_state = SessionState()
        session_state.set_widgets_from_proto(widget_states)

        keyed_metadata = create_metadata("trigger", "trigger_value", True)
        session_state._set_widget_metadata(keyed_metadata)
        other_metadata = create_metadata("trigger2", "trigger_value")
        session_state._set_widget_metadata(other_metadata)

        self.assertEqual(dict(session_state.values()), {"trigger": True})
def _create_mock_session_state(
        initial_state_values: Dict[str, Any]) -> SafeSessionState:
    """Build a SafeSessionState pre-populated with the given values.

    Each key/value pair of `initial_state_values` is written into a fresh
    SessionState, which is then wrapped in a SafeSessionState.
    """
    backing_state = SessionState()
    for state_key, state_value in initial_state_values.items():
        backing_state[state_key] = state_value

    wrapped_state = SafeSessionState(backing_state)
    return wrapped_state
Ejemplo n.º 11
0
    def test_reset_triggers(self):
        """`_reset_triggers` clears the trigger and leaves the int untouched."""
        widget_states = WidgetStates()
        session_state = SessionState()

        _create_widget("trigger", widget_states).trigger_value = True
        _create_widget("int", widget_states).int_value = 123
        session_state.set_widgets_from_proto(widget_states)

        # Both widgets use a deserializer that returns the raw value as-is.
        for widget_id, value_type in (
            ("trigger", "trigger_value"),
            ("int", "int_value"),
        ):
            session_state._set_widget_metadata(
                WidgetMetadata(widget_id, lambda x, s: x, None, value_type))

        # Before the reset both values are readable as set.
        self.assertTrue(session_state["trigger"])
        self.assertEqual(123, session_state["int"])

        session_state._reset_triggers()

        # After the reset only the trigger value has changed.
        self.assertFalse(session_state["trigger"])
        self.assertEqual(123, session_state["int"])
Ejemplo n.º 12
0
def _session_state(draw) -> SessionState:
    """Build a SessionState from values obtained through `draw`.

    `draw` is called with hypothesis strategies (presumably this function
    is wrapped by `hst.composite` -- the decorator is not visible here).
    Four kinds of data are drawn, in this order: plain state entries,
    unkeyed widgets, keyed widgets, and -- only when keyed widgets exist --
    plain entries written under some of the keyed widgets' user keys.
    """
    session_state = SessionState()

    for name, value in draw(NEW_SESSION_STATE).items():
        session_state[name] = value

    unkeyed = draw(
        hst.dictionaries(keys=UNKEYED_WIDGET_IDS, values=hst.integers()))
    for widget_id, value in unkeyed.items():
        session_state.register_widget(
            mock_metadata(widget_id, value), user_key=None)

    drawn_triples = draw(
        hst.lists(hst.tuples(hst.uuids(), USER_KEY, hst.integers())))
    # Maps each user key to (keyed widget id, value); a repeated user key
    # keeps the last triple drawn for it.
    keyed = {}
    for raw_id, widget_key, value in drawn_triples:
        keyed[widget_key] = (as_keyed_widget_id(raw_id, widget_key), value)

    for widget_key, (widget_id, value) in keyed.items():
        session_state.register_widget(
            mock_metadata(widget_id, value), user_key=widget_key)

    if not keyed:
        return session_state

    # sampled_from is only given a non-empty list, hence the guard above.
    overrides = draw(
        hst.dictionaries(
            keys=hst.sampled_from(list(keyed.keys())),
            values=hst.integers(),
        ))
    for name, value in overrides.items():
        session_state[name] = value

    return session_state
Ejemplo n.º 13
0
def _session_state(draw):
    """Build a SessionState from values obtained through `draw`.

    `draw` is called with hypothesis strategies (presumably this function
    is wrapped by `hst.composite` -- the decorator is not visible here).
    Four kinds of data are drawn, in this order: plain state entries,
    unkeyed widgets, keyed widgets, and -- only when keyed widgets exist --
    plain entries written under some of the keyed widgets' user keys.
    """
    session_state = SessionState()

    for name, value in draw(new_session_state).items():
        session_state[name] = value

    unkeyed = draw(
        hst.dictionaries(keys=unkeyed_widget_ids, values=hst.integers()))
    for widget_id, value in unkeyed.items():
        session_state.set_unkeyed_widget(
            mock_metadata(widget_id, value), widget_id)

    drawn_triples = draw(
        hst.lists(hst.tuples(hst.uuids(), user_key, hst.integers())))
    # Maps each user key to (keyed widget id, value); a repeated user key
    # keeps the last triple drawn for it.
    keyed = {}
    for raw_id, widget_key, value in drawn_triples:
        keyed[widget_key] = (as_keyed_widget_id(raw_id, widget_key), value)

    for widget_key, (widget_id, value) in keyed.items():
        session_state.set_keyed_widget(
            mock_metadata(widget_id, value), widget_id, widget_key)

    if not keyed:
        return session_state

    # sampled_from is only given a non-empty list, hence the guard above.
    overrides = draw(
        hst.dictionaries(
            keys=hst.sampled_from(list(keyed.keys())),
            values=hst.integers(),
        ))
    for name, value in overrides.items():
        session_state[name] = value

    return session_state
Ejemplo n.º 14
0
class LazySessionStateAttributeTests(unittest.TestCase):
    """Tests of LazySessionState attribute methods.

    Separate from the others to change patching. Test methods are individually
    patched to avoid issues with mutability: every test patches
    `get_session_state` with its own SessionState seeded with
    ``{"foo": "bar"}``.
    """
    def setUp(self):
        self.lazy_session_state = LazySessionState()

    @patch(
        "streamlit.state.session_state.get_session_state",
        return_value=SessionState(new_session_state={"foo": "bar"}),
    )
    def test_delattr(self, _):
        """Deleting an existing attribute removes the key."""
        del self.lazy_session_state.foo
        assert "foo" not in self.lazy_session_state

    @patch(
        "streamlit.state.session_state.get_session_state",
        return_value=SessionState(new_session_state={"foo": "bar"}),
    )
    def test_delattr_error(self, _):
        """Deleting a missing attribute raises AttributeError."""
        with pytest.raises(AttributeError):
            del self.lazy_session_state.nonexistent

    @patch(
        "streamlit.state.session_state.get_session_state",
        return_value=SessionState(new_session_state={"foo": "bar"}),
    )
    def test_getattr(self, _):
        """Reading an existing attribute returns the stored value."""
        assert self.lazy_session_state.foo == "bar"

    @patch(
        "streamlit.state.session_state.get_session_state",
        return_value=SessionState(new_session_state={"foo": "bar"}),
    )
    def test_getattr_error(self, _):
        """Reading a missing attribute raises AttributeError."""
        with pytest.raises(AttributeError):
            # This must be a read, not a `del`: a delete would exercise
            # __delattr__ and leave attribute lookup untested.
            self.lazy_session_state.nonexistent

    @patch(
        "streamlit.state.session_state.get_session_state",
        return_value=SessionState(new_session_state={"foo": "bar"}),
    )
    def test_setattr(self, _):
        """A value assigned as an attribute can be read back."""
        self.lazy_session_state.corge = "grault2"
        assert self.lazy_session_state.corge == "grault2"
Ejemplo n.º 15
0
    def test_handle_save_request(self, _1):
        """Test that handle_save_request serializes files correctly.

        A temporary ReportContext is installed for the duration of the test.
        The original context is re-installed in a `finally` clause so that a
        failing assertion cannot leave the temporary context in place.
        """
        # Create a ReportSession with some mocked bits
        rs = ReportSession(
            self.io_loop, "mock_report.py", "", UploadedFileManager(), None
        )
        rs._report.report_id = "TestReportID"

        orig_ctx = get_report_ctx()
        ctx = ReportContext(
            "TestSessionID",
            rs._report.enqueue,
            "",
            SessionState(),
            UploadedFileManager(),
        )
        add_report_ctx(ctx=ctx)

        try:
            rs._scriptrunner = MagicMock()

            storage = MockStorage()
            rs._storage = storage

            # Send two deltas: empty and markdown
            st.empty()
            st.markdown("Text!")

            yield rs.handle_save_request(_create_mock_websocket())

            # Check the order of the received files. Manifest should be last.
            self.assertEqual(3, len(storage.files))
            self.assertEqual(
                "reports/TestReportID/0.pb", storage.get_filename(0))
            self.assertEqual(
                "reports/TestReportID/1.pb", storage.get_filename(1))
            self.assertEqual(
                "reports/TestReportID/manifest.pb", storage.get_filename(2))

            # Check the manifest
            manifest = storage.get_message(2, StaticManifest)
            self.assertEqual("mock_report", manifest.name)
            self.assertEqual(2, manifest.num_messages)
            self.assertEqual(StaticManifest.DONE, manifest.server_status)

            # Check that the deltas we sent match messages in storage
            sent_messages = rs._report._master_queue._queue
            received_messages = [
                storage.get_message(0, ForwardMsg),
                storage.get_message(1, ForwardMsg),
            ]

            self.assertEqual(sent_messages, received_messages)
        finally:
            # Runs on success, on assertion failure, and on errors raised
            # while the coroutine is suspended at the `yield`.
            add_report_ctx(ctx=orig_ctx)
Ejemplo n.º 16
0
    def test_marshall_excludes_widgets_without_state(self):
        """`get_widget_states` is expected to skip metadata-only widgets.

        "trigger" has a value but no metadata; "other_widget" has metadata
        but no value. Only "trigger" is expected in the marshalled result.
        """
        incoming_states = WidgetStates()
        _create_widget("trigger", incoming_states).trigger_value = True

        session_state = SessionState()
        session_state.set_widgets_from_proto(incoming_states)

        metadata_only = WidgetMetadata(
            "other_widget", lambda x, s: x, None, "trigger_value", True)
        session_state._set_widget_metadata(metadata_only)

        marshalled = session_state.get_widget_states()

        self.assertEqual(len(marshalled), 1)
        self.assertEqual(marshalled[0].id, "trigger")
Ejemplo n.º 17
0
    def setUp(self, override_root=True):
        """Create a fake script-run environment for the test.

        A ForwardMsgQueue and a ScriptRunContext that enqueues into it are
        always created. When `override_root` is true, the context currently
        returned by `get_script_run_ctx()` is saved in `orig_report_ctx` and
        the new one is attached to the current thread.
        """
        self.override_root = override_root
        self.orig_report_ctx = None
        self.forward_msg_queue = ForwardMsgQueue()

        script_run_ctx = ScriptRunContext(
            session_id="test session id",
            enqueue=self.forward_msg_queue.enqueue,
            query_string="",
            session_state=SessionState(),
            uploaded_file_mgr=UploadedFileManager(),
        )
        self.new_script_run_ctx = script_run_ctx

        if self.override_root:
            self.orig_report_ctx = get_script_run_ctx()
            add_script_run_ctx(threading.current_thread(), script_run_ctx)

        self.app_session = FakeAppSession()
Ejemplo n.º 18
0
    def test_set_page_config_immutable(self):
        """st.set_page_config must be called at most once"""

        def discard(msg):
            # Stand-in enqueue callback: accepts the message, does nothing.
            return None

        ctx = ScriptRunContext(
            "TestSessionID",
            discard,
            "",
            SessionState(),
            UploadedFileManager(),
        )

        page_config_msg = ForwardMsg()
        page_config_msg.page_config_changed.title = "foo"

        # The first page-config message is accepted; the second must raise.
        ctx.enqueue(page_config_msg)
        with self.assertRaises(StreamlitAPIException):
            ctx.enqueue(page_config_msg)
Ejemplo n.º 19
0
    def setUp(self, override_root=True):
        """Create a fake report environment for the test.

        A ReportQueue is always created. When `override_root` is true, the
        context currently returned by `get_report_ctx()` is saved in
        `orig_report_ctx` and a new ReportContext that enqueues into the
        queue is attached to the current thread.
        """
        self.override_root = override_root
        self.orig_report_ctx = None
        self.report_queue = ReportQueue()

        if self.override_root:
            self.orig_report_ctx = get_report_ctx()
            test_ctx = ReportContext(
                session_id="test session id",
                enqueue=self.report_queue.enqueue,
                query_string="",
                session_state=SessionState(),
                uploaded_file_mgr=UploadedFileManager(),
            )
            add_report_ctx(threading.current_thread(), test_ctx)

        self.report_session = FakeReportSession()
Ejemplo n.º 20
0
    def __init__(self, script_name: str):
        """Set up a ScriptRunner for `script_name` in the test_data directory.

        Every event emitted through `on_event` is recorded in `self.events`,
        with its keyword data in `self.event_data`. The messages carried by
        ENQUEUE_FORWARD_MSG events are also put on `self.forward_msg_queue`.
        """
        # Destination for the deltas enqueued by DeltaGenerator.
        self.forward_msg_queue = ForwardMsgQueue()

        main_script_path = os.path.join(
            os.path.dirname(__file__), "test_data", script_name)
        session_data = SessionData(main_script_path, "test command line")

        super().__init__(
            session_id="test session id",
            session_data=session_data,
            client_state=ClientState(),
            session_state=SessionState(),
            uploaded_file_mgr=UploadedFileManager(),
            initial_rerun_data=RerunData(),
        )

        # Uncaught exceptions thrown by our run thread accumulate here.
        self.script_thread_exceptions: List[BaseException] = []

        # Every ScriptRunnerEvent we emit, and its keyword data.
        self.events: List[ScriptRunnerEvent] = []
        self.event_data: List[Any] = []

        def _on_runner_event(sender: Optional[ScriptRunner],
                             event: ScriptRunnerEvent, **kwargs) -> None:
            # ScriptRunner.on_event must only deliver events that have no
            # sender or that come from this runner.
            sender_is_expected = sender is None or sender == self
            assert sender_is_expected, "Unexpected ScriptRunnerEvent sender!"

            self.events.append(event)
            self.event_data.append(kwargs)

            if event != ScriptRunnerEvent.ENQUEUE_FORWARD_MSG:
                return

            # Forward ENQUEUE_FORWARD_MSG payloads to our queue.
            self.forward_msg_queue.enqueue(kwargs["forward_msg"])

        self.on_event.connect(_on_runner_event, weak=False)
Ejemplo n.º 21
0
    def test_disallow_set_page_config_twice(self):
        """st.set_page_config cannot be called twice"""

        def discard(msg):
            # Stand-in enqueue callback: accepts the message, does nothing.
            return None

        ctx = ScriptRunContext(
            "TestSessionID",
            discard,
            "",
            SessionState(),
            UploadedFileManager(),
        )
        ctx.on_script_start()

        first_msg = ForwardMsg()
        first_msg.page_config_changed.title = "foo"
        ctx.enqueue(first_msg)

        # A second page-config message must be rejected even though it is a
        # distinct object with a different title.
        with self.assertRaises(StreamlitAPIException):
            second_msg = ForwardMsg()
            second_msg.page_config_changed.title = "bar"
            ctx.enqueue(second_msg)
Ejemplo n.º 22
0
    def test_set_page_config_reset(self):
        """st.set_page_config should be allowed after a rerun"""

        def discard(msg):
            # Stand-in enqueue callback: accepts the message, does nothing.
            return None

        ctx = ScriptRunContext(
            "TestSessionID",
            discard,
            "",
            SessionState(),
            UploadedFileManager(),
        )
        ctx.on_script_start()

        page_config_msg = ForwardMsg()
        page_config_msg.page_config_changed.title = "foo"
        ctx.enqueue(page_config_msg)

        # After reset() and a new script start, the same message must be
        # accepted again.
        ctx.reset()
        try:
            ctx.on_script_start()
            ctx.enqueue(page_config_msg)
        except StreamlitAPIException:
            self.fail("set_page_config should have succeeded after reset!")
Ejemplo n.º 23
0
    def test_set_page_config_first(self):
        """st.set_page_config must be called before other st commands
        when the script has been marked as started"""

        def discard(msg):
            # Stand-in enqueue callback: accepts the message, does nothing.
            return None

        ctx = ScriptRunContext(
            "TestSessionID",
            discard,
            "",
            SessionState(),
            UploadedFileManager(),
        )
        ctx.on_script_start()

        element_msg = ForwardMsg()
        element_msg.delta.new_element.markdown.body = "foo"

        page_config_msg = ForwardMsg()
        page_config_msg.page_config_changed.title = "foo"

        # Once an element message has been enqueued, a page-config message
        # must be rejected.
        ctx.enqueue(element_msg)
        with self.assertRaises(StreamlitAPIException):
            ctx.enqueue(page_config_msg)
Ejemplo n.º 24
0
    def test_cached_st_function_warning(self, _, cache_decorator, call_stack):
        """Ensure we properly warn when st.foo functions are called
        inside a cached function.

        A temporary ScriptRunContext is attached to the current thread for
        the duration of the test. The previous context is re-attached in a
        `finally` clause so that a failing assertion cannot leave the
        temporary context attached to the thread.
        """
        forward_msg_queue = ForwardMsgQueue()
        orig_report_ctx = get_script_run_ctx()
        add_script_run_ctx(
            threading.current_thread(),
            ScriptRunContext(
                session_id="test session id",
                enqueue=forward_msg_queue.enqueue,
                query_string="",
                session_state=SessionState(),
                uploaded_file_mgr=None,
            ),
        )
        try:
            with patch.object(call_stack,
                              "_show_cached_st_function_warning") as warning:
                st.text("foo")
                warning.assert_not_called()

                @cache_decorator
                def cached_func():
                    st.text("Inside cached func")

                cached_func()
                warning.assert_called_once()

                warning.reset_mock()

                # Make sure everything got reset properly
                st.text("foo")
                warning.assert_not_called()

                # Test warning suppression
                @cache_decorator(suppress_st_warning=True)
                def suppressed_cached_func():
                    st.text("No warnings here!")

                suppressed_cached_func()

                warning.assert_not_called()

                # Test nested st.cache functions
                @cache_decorator
                def outer():
                    @cache_decorator
                    def inner():
                        st.text("Inside nested cached func")

                    return inner()

                outer()
                warning.assert_called_once()

                warning.reset_mock()

                # Test st.cache functions that raise errors
                with self.assertRaises(RuntimeError):

                    @cache_decorator
                    def cached_raise_error():
                        st.text("About to throw")
                        raise RuntimeError("avast!")

                    cached_raise_error()

                warning.assert_called_once()
                warning.reset_mock()

                # Make sure everything got reset properly
                st.text("foo")
                warning.assert_not_called()

                # Test st.cache functions with widgets
                @cache_decorator
                def cached_widget():
                    st.button("Press me!")

                cached_widget()

                warning.assert_called_once()
                warning.reset_mock()

                # Make sure everything got reset properly
                st.text("foo")
                warning.assert_not_called()
        finally:
            # Restore the context that was active before the test, whether
            # or not the assertions above passed.
            add_script_run_ctx(threading.current_thread(), orig_report_ctx)
Ejemplo n.º 25
0
 def test_get_prev_widget_value_nonexistent(self):
     """Item access with an unknown widget id raises KeyError."""
     session_state = SessionState()
     with self.assertRaises(KeyError):
         session_state["fake_widget_id"]
    def test_rerun_data_coalescing(self):
        """Test that multiple RERUN requests get coalesced with
        expected values.

        (This is similar to widgets_test.test_coalesce_widget_states -
        it's testing the same thing, but through the ScriptEventQueue
        interface.)

        Three scenarios are checked in sequence on the same queue: two
        requests that both carry widget states, requests with no widget
        states followed by one that has them, and a request with widget
        states followed by one without.
        """
        queue = ScriptRequestQueue()
        session_state = SessionState()

        # First request: trigger is True, int is 123.
        states = WidgetStates()
        _create_widget("trigger", states).trigger_value = True
        _create_widget("int", states).int_value = 123

        queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=states))

        # Second request: trigger is False, int is 456.
        states = WidgetStates()
        _create_widget("trigger", states).trigger_value = False
        _create_widget("int", states).int_value = 456

        # Register metadata for both widgets; the deserializer lambdas
        # return the raw value unchanged.
        session_state.set_metadata(
            WidgetMetadata("trigger", lambda x, s: x, None, "trigger_value"))
        session_state.set_metadata(
            WidgetMetadata("int", lambda x, s: x, lambda x: x, "int_value"))

        queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=states))

        # Both enqueued requests should come back as a single RERUN event.
        event, data = queue.dequeue()
        self.assertEqual(event, ScriptRequest.RERUN)

        session_state.set_widgets_from_proto(data.widget_states)

        # Coalesced triggers should be True if either the old or
        # new value was True
        self.assertEqual(True, session_state.get("trigger"))

        # Other widgets should have their newest value
        self.assertEqual(456, session_state.get("int"))

        # We should have no more events
        self.assertEqual((None, None), queue.dequeue(),
                         "Expected empty event queue")

        # Test that we can coalesce if previous widget state is None
        queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=None))
        queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=None))

        states = WidgetStates()
        _create_widget("int", states).int_value = 789

        queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=states))

        event, data = queue.dequeue()
        session_state.set_widgets_from_proto(data.widget_states)

        self.assertEqual(event, ScriptRequest.RERUN)
        self.assertEqual(789, session_state.get("int"))

        # We should have no more events
        self.assertEqual((None, None), queue.dequeue(),
                         "Expected empty event queue")

        # Test that we can coalesce if our *new* widget state is None
        states = WidgetStates()
        _create_widget("int", states).int_value = 101112

        queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=states))

        queue.enqueue(ScriptRequest.RERUN, RerunData(widget_states=None))

        event, data = queue.dequeue()
        session_state.set_widgets_from_proto(data.widget_states)

        # The value from the earlier request is expected to be kept.
        self.assertEqual(event, ScriptRequest.RERUN)
        self.assertEqual(101112, session_state.get("int"))

        # We should have no more events
        self.assertEqual((None, None), queue.dequeue(),
                         "Expected empty event queue")
Ejemplo n.º 27
0
    def test_get(self):
        """Widget values of each proto field type are readable by item access."""
        # (widget id, proto value field, value)
        widget_rows = (
            ("trigger", "trigger_value", True),
            ("bool", "bool_value", True),
            ("float", "double_value", 0.5),
            ("int", "int_value", 123),
            ("string", "string_value", "howdy!"),
        )

        widget_states = WidgetStates()
        for widget_id, field_name, value in widget_rows:
            setattr(_create_widget(widget_id, widget_states), field_name, value)

        session_state = SessionState()
        session_state.set_widgets_from_proto(widget_states)

        for widget_id, field_name, _ in widget_rows:
            session_state._set_widget_metadata(
                create_metadata(widget_id, field_name))

        self.assertEqual(True, session_state["trigger"])
        self.assertEqual(True, session_state["bool"])
        # The float is compared approximately.
        self.assertAlmostEqual(0.5, session_state["float"])
        self.assertEqual(123, session_state["int"])
        self.assertEqual("howdy!", session_state["string"])
Ejemplo n.º 28
0
 def __init__(self):
     """Create the object with a fresh, empty SessionState."""
     self._session_state = SessionState()
Ejemplo n.º 29
0
class AppSession:
    """
    Contains session data for a single "user" of an active app
    (that is, a connected browser tab).

    Each AppSession has its own SessionData, root DeltaGenerator, ScriptRunner,
    and widget state.

    An AppSession is attached to each thread involved in running its script.

    """
    def __init__(
        self,
        ioloop: tornado.ioloop.IOLoop,
        session_data: SessionData,
        uploaded_file_manager: UploadedFileManager,
        message_enqueued_callback: Optional[Callable[[], None]],
        local_sources_watcher: LocalSourcesWatcher,
    ):
        """Initialize the AppSession.

        Parameters
        ----------
        ioloop : tornado.ioloop.IOLoop
            The Tornado IOLoop that we're running within.

        session_data : SessionData
            Object storing parameters related to running a script

        uploaded_file_manager : UploadedFileManager
            The server's UploadedFileManager.

        message_enqueued_callback : Callable[[], None] | None
            After enqueuing a message, this callable notification will be
            invoked. If falsy, no notification is made.

        local_sources_watcher: LocalSourcesWatcher
            The file watcher that lets the session know local files have changed.

        """
        # Each AppSession has a unique string ID.
        self.id = str(uuid.uuid4())

        self._ioloop = ioloop
        self._session_data = session_data
        self._uploaded_file_mgr = uploaded_file_manager
        self._message_enqueued_callback = message_enqueued_callback

        self._state = AppSessionState.APP_NOT_RUNNING

        # Need to remember the client state here because when a script reruns
        # due to the source code changing we need to pass in the previous client state.
        self._client_state = ClientState()

        self._local_sources_watcher = local_sources_watcher
        self._local_sources_watcher.register_file_change_callback(
            self._on_source_file_changed)
        # The returned callable is invoked in shutdown() to stop listening.
        self._stop_config_listener = config.on_config_parsed(
            self._on_source_file_changed, force_connect=True)

        # The script should rerun when the `secrets.toml` file has been changed.
        secrets._file_change_listener.connect(self._on_secrets_file_changed)

        self._run_on_save = config.get_option("server.runOnSave")

        # The ScriptRequestQueue is the means by which we communicate
        # with the active ScriptRunner.
        self._script_request_queue = ScriptRequestQueue()

        # Created on demand by _maybe_create_scriptrunner().
        self._scriptrunner: Optional[ScriptRunner] = None

        # This needs to be lazily imported to avoid a dependency cycle.
        from streamlit.state.session_state import SessionState

        self._session_state = SessionState()

        LOGGER.debug("AppSession initialized (id=%s)", self.id)

    def flush_browser_queue(self) -> List[ForwardMsg]:
        """Clear the forward message queue and return the messages it contained.

        The Server calls this periodically to deliver new messages
        to the browser connected to this app.

        Returns
        -------
        list[ForwardMsg]
            The messages that were removed from the queue and should
            be delivered to the browser.

        """
        return self._session_data.flush_browser_queue()

    def shutdown(self) -> None:
        """Shut down the AppSession.

        It's an error to use an AppSession after it's been shut down.
        Calling this again once shutdown has been requested does nothing.

        """
        if self._state != AppSessionState.SHUTDOWN_REQUESTED:
            LOGGER.debug("Shutting down (id=%s)", self.id)
            # Clear any unused session files in upload file manager and media
            # file manager
            self._uploaded_file_mgr.remove_session_files(self.id)
            in_memory_file_manager.clear_session_files(self.id)
            in_memory_file_manager.del_expired_files()

            # Shut down the ScriptRunner, if one is active.
            # self._state must not be set to SHUTDOWN_REQUESTED until
            # after this is called.
            if self._scriptrunner is not None:
                self._enqueue_script_request(ScriptRequest.SHUTDOWN)

            self._state = AppSessionState.SHUTDOWN_REQUESTED
            # Undo the listener registrations made in __init__.
            self._local_sources_watcher.close()
            if self._stop_config_listener is not None:
                self._stop_config_listener()
            secrets._file_change_listener.disconnect(
                self._on_secrets_file_changed)

    def enqueue(self, msg: ForwardMsg) -> None:
        """Enqueue a new ForwardMsg to our browser queue.

        This can be called on both the main thread and a ScriptRunner
        run thread.

        The message is dropped when the "client.displayEnabled" config
        option is falsy.

        Parameters
        ----------
        msg : ForwardMsg
            The message to enqueue

        """
        if not config.get_option("client.displayEnabled"):
            return

        # Avoid having two maybe_handle_execution_control_request running on
        # top of each other when tracer is installed. This leads to a lock
        # contention.
        if not config.get_option("runner.installTracer"):
            # If we have an active ScriptRunner, signal that it can handle an
            # execution control request. (Copy the scriptrunner reference to
            # avoid it being unset from underneath us, as this function can be
            # called outside the main thread.)
            scriptrunner = self._scriptrunner

            if scriptrunner is not None:
                scriptrunner.maybe_handle_execution_control_request()

        self._session_data.enqueue(msg)
        if self._message_enqueued_callback:
            self._message_enqueued_callback()

    def enqueue_exception(self, e: BaseException) -> None:
        """Enqueue an Exception message.

        Parameters
        ----------
        e : BaseException
            The exception to marshall into the message.

        """
        # This does a few things:
        # 1) Clears the current app in the browser.
        # 2) Marks the current app as "stopped" in the browser.
        # 3) HACK: Resets any script params that may have been broken (e.g. the
        # command-line when rerunning with wrong argv[0])
        self._on_scriptrunner_event(
            ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS)
        self._on_scriptrunner_event(ScriptRunnerEvent.SCRIPT_STARTED)
        self._on_scriptrunner_event(
            ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS)

        msg = ForwardMsg()
        exception_utils.marshall(msg.delta.new_element.exception, e)

        self.enqueue(msg)

    def request_rerun(self, client_state: Optional[ClientState]) -> None:
        """Signal that we're interested in running the script.

        If the script is not already running, it will be started immediately.
        Otherwise, a rerun will be requested.

        Parameters
        ----------
        client_state : streamlit.proto.ClientState_pb2.ClientState | None
            The ClientState protobuf to run the script with, or None
            to use previous client state.

        """
        if client_state:
            rerun_data = RerunData(client_state.query_string,
                                   client_state.widget_states)
        else:
            # No client state given: enqueue a RerunData built with defaults.
            rerun_data = RerunData()

        self._enqueue_script_request(ScriptRequest.RERUN, rerun_data)

    @property
    def session_state(self) -> "SessionState":
        """The SessionState object created for this session in __init__."""
        return self._session_state

    def _on_source_file_changed(self) -> None:
        """One of our source files changed. Schedule a rerun if appropriate."""
        if self._run_on_save:
            self.request_rerun(self._client_state)
        else:
            # Not rerunning automatically: just tell the browser about it.
            self._enqueue_file_change_message()

    def _on_secrets_file_changed(self, _) -> None:
        """Called when `secrets._file_change_listener` emits a Signal."""

        # NOTE: At the time of writing, this function only calls `_on_source_file_changed`.
        # The reason behind creating this function instead of just passing `_on_source_file_changed`
        # to `connect` / `disconnect` directly is that every function that is passed to `connect` / `disconnect`
        # must have at least one argument for `sender` (in this case we don't really care about it, thus `_`),
        # and introducing an unnecessary argument to `_on_source_file_changed` just for this purpose sounded finicky.
        self._on_source_file_changed()

    def _clear_queue(self) -> None:
        """Call `clear()` on this session's SessionData."""
        self._session_data.clear()

    def _on_scriptrunner_event(
        self,
        event: ScriptRunnerEvent,
        exception: Optional[BaseException] = None,
        client_state: Optional[ClientState] = None,
    ) -> None:
        """Called when our ScriptRunner emits an event.

        This is *not* called on the main thread.

        Parameters
        ----------
        event : ScriptRunnerEvent

        exception : BaseException | None
            An exception thrown during compilation. Set only for the
            SCRIPT_STOPPED_WITH_COMPILE_ERROR event.

        client_state : streamlit.proto.ClientState_pb2.ClientState | None
            The ScriptRunner's final ClientState. Set only for the
            SHUTDOWN event.

        """
        LOGGER.debug("OnScriptRunnerEvent: %s", event)

        # Remembered so we can detect a running/not-running transition below.
        prev_state = self._state

        if event == ScriptRunnerEvent.SCRIPT_STARTED:
            # Never leave SHUTDOWN_REQUESTED once it has been entered.
            if self._state != AppSessionState.SHUTDOWN_REQUESTED:
                self._state = AppSessionState.APP_IS_RUNNING

            self._clear_queue()
            self._enqueue_new_session_message()

        elif (event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
              or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR):

            if self._state != AppSessionState.SHUTDOWN_REQUESTED:
                self._state = AppSessionState.APP_NOT_RUNNING

            script_succeeded = event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS

            self._enqueue_script_finished_message(
                ForwardMsg.FINISHED_SUCCESSFULLY if script_succeeded else
                ForwardMsg.FINISHED_WITH_COMPILE_ERROR)

            if script_succeeded:
                # When a script completes successfully, we update our
                # LocalSourcesWatcher to account for any source code changes
                # that change which modules should be watched. (This is run on
                # the main thread, because LocalSourcesWatcher is not
                # thread safe.)
                self._ioloop.spawn_callback(
                    self._local_sources_watcher.update_watched_modules)
            else:
                # Compile error: forward the exception to the browser.
                msg = ForwardMsg()
                exception_utils.marshall(
                    msg.session_event.script_compilation_exception, exception)
                self.enqueue(msg)

        elif event == ScriptRunnerEvent.SHUTDOWN:
            # When ScriptRunner shuts down, update our local reference to it,
            # and check to see if we need to spawn a new one. (This is run on
            # the main thread.)

            assert (
                client_state
                is not None), "client_state must be set for the SHUTDOWN event"

            if self._state == AppSessionState.SHUTDOWN_REQUESTED:
                # Only clear media files if the script is done running AND the
                # session is actually shutting down.
                in_memory_file_manager.clear_session_files(self.id)

            def on_shutdown():
                # We assert above that this is non-null
                self._client_state = cast(ClientState, client_state)

                self._scriptrunner = None
                # Because a new ScriptRequest could have been enqueued while
                # the scriptrunner was shutting down, we check to see if we
                # should create a new one. (Otherwise, a newly-enqueued
                # ScriptRequest won't be processed until another request is
                # enqueued.)
                self._maybe_create_scriptrunner()

            self._ioloop.spawn_callback(on_shutdown)

        # Send a message if our run state changed
        app_was_running = prev_state == AppSessionState.APP_IS_RUNNING
        app_is_running = self._state == AppSessionState.APP_IS_RUNNING
        if app_is_running != app_was_running:
            self._enqueue_session_state_changed_message()

    def _enqueue_session_state_changed_message(self) -> None:
        """Enqueue a session_state_changed ForwardMsg.

        The message carries the current run_on_save flag and whether the
        app is currently in the APP_IS_RUNNING state.
        """
        msg = ForwardMsg()
        msg.session_state_changed.run_on_save = self._run_on_save
        msg.session_state_changed.script_is_running = (
            self._state == AppSessionState.APP_IS_RUNNING)
        self.enqueue(msg)

    def _enqueue_file_change_message(self) -> None:
        """Enqueue a session_event ForwardMsg flagging script_changed_on_disk."""
        LOGGER.debug("Enqueuing script_changed message (id=%s)", self.id)
        msg = ForwardMsg()
        msg.session_event.script_changed_on_disk = True
        self.enqueue(msg)

    def _enqueue_new_session_message(self) -> None:
        """Enqueue a new_session ForwardMsg describing this script run.

        The message contains a newly generated script run ID, the script's
        name and path, config and theme info, and the immutable
        "initialize" payload (user info, environment info, session state,
        command line, and session ID).
        """
        msg = ForwardMsg()

        msg.new_session.script_run_id = _generate_scriptrun_id()
        msg.new_session.name = self._session_data.name
        msg.new_session.script_path = self._session_data.script_path

        _populate_config_msg(msg.new_session.config)
        _populate_theme_msg(msg.new_session.custom_theme)

        # Immutable session data. We send this every time a new session is
        # started, to avoid having to track whether the client has already
        # received it. It does not change from run to run; it's up to the
        # client to perform one-time initialization only once.
        imsg = msg.new_session.initialize

        _populate_user_info_msg(imsg.user_info)

        imsg.environment_info.streamlit_version = __version__
        imsg.environment_info.python_version = ".".join(
            map(str, sys.version_info))

        imsg.session_state.run_on_save = self._run_on_save
        imsg.session_state.script_is_running = (
            self._state == AppSessionState.APP_IS_RUNNING)

        imsg.command_line = self._session_data.command_line
        imsg.session_id = self.id

        self.enqueue(msg)

    def _enqueue_script_finished_message(
            self, status: "ForwardMsg.ScriptFinishedStatus.ValueType") -> None:
        """Enqueue a script_finished ForwardMsg.

        Parameters
        ----------
        status : ForwardMsg.ScriptFinishedStatus.ValueType
            The status value to store in the message's script_finished field.

        """
        msg = ForwardMsg()
        msg.script_finished = status
        self.enqueue(msg)

    def handle_git_information_request(self) -> None:
        """Enqueue a git_info_changed ForwardMsg for the script's Git repo.

        Nothing is enqueued when the repo info is unavailable (None).
        Any exception raised along the way is logged at debug level and
        otherwise ignored.
        """
        msg = ForwardMsg()

        try:
            from streamlit.git_util import GitRepo

            repo = GitRepo(self._session_data.script_path)

            repo_info = repo.get_repo_info()
            if repo_info is None:
                return

            repository_name, branch, module = repo_info

            msg.git_info_changed.repository = repository_name
            msg.git_info_changed.branch = branch
            msg.git_info_changed.module = module

            msg.git_info_changed.untracked_files[:] = repo.untracked_files
            msg.git_info_changed.uncommitted_files[:] = repo.uncommitted_files

            if repo.is_head_detached:
                msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED
            elif len(repo.ahead_commits) > 0:
                msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE
            else:
                msg.git_info_changed.state = GitInfo.GitStates.DEFAULT

            self.enqueue(msg)
        except Exception as e:
            # Users may never even install Git in the first place, so this
            # error requires no action. It can be useful for debugging.
            LOGGER.debug("Obtaining Git information produced an error",
                         exc_info=e)

    def handle_rerun_script_request(self,
                                    client_state: Optional[ClientState] = None
                                    ) -> None:
        """Tell the ScriptRunner to re-run its script.

        Parameters
        ----------
        client_state : streamlit.proto.ClientState_pb2.ClientState | None
            The ClientState protobuf to run the script with, or None
            to use previous client state.

        """
        self.request_rerun(client_state)

    def handle_stop_script_request(self) -> None:
        """Tell the ScriptRunner to stop running its script."""
        self._enqueue_script_request(ScriptRequest.STOP)

    def handle_clear_cache_request(self) -> None:
        """Clear this app's cache.

        Because this cache is global, it will be cleared for all users.
        This session's SessionState is cleared as well.

        """
        legacy_caching.clear_cache()
        caching.memo.clear()
        caching.singleton.clear()
        self._session_state.clear_state()

    def handle_set_run_on_save_request(self, new_value: bool) -> None:
        """Change our run_on_save flag to the given value.

        The browser will be notified of the change.

        Parameters
        ----------
        new_value : bool
            New run_on_save value

        """
        self._run_on_save = new_value
        self._enqueue_session_state_changed_message()

    def _enqueue_script_request(self,
                                request: ScriptRequest,
                                data: Any = None) -> None:
        """Enqueue a ScriptRequest into our ScriptRequestQueue.

        If a ScriptRunner is not already active, one will be created
        to handle the request. Requests made after shutdown has been
        requested are discarded with a warning.

        Parameters
        ----------
        request : ScriptRequest
            The type of request.

        data : Any
            Data associated with the request, if any.

        """
        if self._state == AppSessionState.SHUTDOWN_REQUESTED:
            # Pass `request` as a logging argument (rather than pre-formatting
            # with `%`) so formatting is deferred to the logging framework,
            # matching the other LOGGER calls in this class.
            LOGGER.warning("Discarding %s request after shutdown", request)
            return

        self._script_request_queue.enqueue(request, data)
        self._maybe_create_scriptrunner()

    def _maybe_create_scriptrunner(self) -> None:
        """Create a new ScriptRunner if we have unprocessed script requests.

        This is called every time a ScriptRequest is enqueued, and also
        after a ScriptRunner shuts down, in case new requests were enqueued
        during its termination.

        This function should only be called on the main thread.

        """
        # Nothing to do if we're shutting down, a runner already exists,
        # or there is no pending request.
        if (self._state == AppSessionState.SHUTDOWN_REQUESTED
                or self._scriptrunner is not None
                or not self._script_request_queue.has_request):
            return

        # Create the ScriptRunner, attach event handlers, and start it
        self._scriptrunner = ScriptRunner(
            session_id=self.id,
            session_data=self._session_data,
            enqueue_forward_msg=self.enqueue,
            client_state=self._client_state,
            request_queue=self._script_request_queue,
            session_state=self._session_state,
            uploaded_file_mgr=self._uploaded_file_mgr,
        )
        self._scriptrunner.on_event.connect(self._on_scriptrunner_event)
        self._scriptrunner.start()
Ejemplo n.º 30
0
    def __init__(
        self,
        ioloop: tornado.ioloop.IOLoop,
        session_data: SessionData,
        uploaded_file_manager: UploadedFileManager,
        message_enqueued_callback: Optional[Callable[[], None]],
        local_sources_watcher: LocalSourcesWatcher,
    ):
        """Set up a new AppSession in the APP_NOT_RUNNING state.

        Besides storing its arguments, this registers the session for
        source-file, config-parsed, and secrets-file change notifications,
        and creates the request queue and SessionState it will use.

        Parameters
        ----------
        ioloop : tornado.ioloop.IOLoop
            The Tornado IOLoop this session runs within.

        session_data : SessionData
            Object holding the parameters related to running a script.

        uploaded_file_manager : UploadedFileManager
            The server's UploadedFileManager.

        message_enqueued_callback : Callable[[], None]
            Notification callable to invoke after a message is enqueued.

        local_sources_watcher: LocalSourcesWatcher
            The file watcher that tells the session when local files change.

        """
        # A unique string ID identifies each AppSession.
        session_uuid = uuid.uuid4()
        self.id = str(session_uuid)

        self._ioloop = ioloop
        self._session_data = session_data
        self._uploaded_file_mgr = uploaded_file_manager
        self._message_enqueued_callback = message_enqueued_callback

        self._state = AppSessionState.APP_NOT_RUNNING

        # The previous client state is kept so that a rerun triggered by a
        # source code change can be given it.
        self._client_state = ClientState()

        watcher = local_sources_watcher
        self._local_sources_watcher = watcher
        watcher.register_file_change_callback(self._on_source_file_changed)
        self._stop_config_listener = config.on_config_parsed(
            self._on_source_file_changed,
            force_connect=True,
        )

        # A change to the `secrets.toml` file should rerun the script too.
        secrets._file_change_listener.connect(self._on_secrets_file_changed)

        run_on_save = config.get_option("server.runOnSave")
        self._run_on_save = run_on_save

        # We talk to the active ScriptRunner through the ScriptRequestQueue.
        self._script_request_queue = ScriptRequestQueue()

        self._scriptrunner: Optional[ScriptRunner] = None

        # Imported lazily here: a module-level import would create a
        # dependency cycle.
        from streamlit.state.session_state import SessionState

        self._session_state = SessionState()

        LOGGER.debug("AppSession initialized (id=%s)", self.id)