Example no. 1
0
    def test_simple_add_rows_with_clear_queue(self):
        """Test plain old add_rows after clearing the queue."""
        methods = (
            self._get_unnamed_data_methods() + self._get_named_data_methods()
        )

        for data_method in methods:
            # Build a fresh data-carrying element (e.g. st.dataframe).
            element = data_method(DATAFRAME)

            # Verify the element starts with 2 rows.
            proto = data_frame_proto._get_data_frame(
                self.get_delta_from_queue())
            self.assertEqual(len(proto.data.cols[0].int64s.data), 2)

            # This is what we're testing: add_rows against an empty queue.
            self.report_queue.clear()
            element.add_rows(NEW_ROWS)

            # The appended delta should carry all 3 new rows.
            add_rows_proto = self.get_delta_from_queue().add_rows
            self.assertEqual(
                len(add_rows_proto.data.data.cols[0].int64s.data), 3)

            # Reset global state so the next iteration starts clean.
            get_report_ctx().reset()
            self.report_queue.clear()
Example no. 2
0
    def test_with_index_add_rows(self):
        """Test plain old add_rows."""
        for data_method in self._get_unnamed_data_methods():
            # Build a fresh data-carrying element (e.g. st.dataframe).
            element = data_method(DATAFRAME_WITH_INDEX)

            # Verify the element starts with 2 rows.
            proto = data_frame_proto._get_data_frame(
                self.get_delta_from_queue())
            self.assertEqual(len(proto.data.cols[0].int64s.data), 2)

            # This is what we're testing:
            element.add_rows(NEW_ROWS_WITH_INDEX)

            # After appending, the element should hold 5 rows total.
            proto = data_frame_proto._get_data_frame(
                self.get_delta_from_queue())
            self.assertEqual(len(proto.data.cols[0].int64s.data), 5)

            # Reset global state so the next iteration starts clean.
            get_report_ctx().reset()
            self.report_queue.clear()
    def test_add_rows_fails_when_wrong_shape(self):
        """Test that add_rows raises error when input has wrong shape."""
        methods = (
            self._get_unnamed_data_methods() + self._get_named_data_methods()
        )

        for data_method in methods:
            # Build a fresh data-carrying element (e.g. st.dataframe).
            element = data_method(DATAFRAME)

            # This is what we're testing: a shape mismatch must raise.
            with self.assertRaises(ValueError):
                element.add_rows(NEW_ROWS_WRONG_SHAPE)

            # Reset global state so the next iteration starts clean.
            get_report_ctx().reset()
            self.report_queue.clear()
Example no. 4
0
 def _create_chart(self, type='line', height=0):
     """Build a two-series (loss/acc) chart of the given Recharts type."""
     # Start from an empty frame; rows are presumably streamed in later.
     chart = Chart(pd.DataFrame(columns=['loss', 'acc']),
                   '%s_chart' % type,
                   height=height)

     # Left-hand axis for loss, right-hand axis for accuracy.
     chart.y_axis(type='number',
                  y_axis_id="loss_axis",
                  allow_data_overflow="true")
     chart.y_axis(type='number',
                  orientation='right',
                  y_axis_id="acc_axis",
                  allow_data_overflow="true")

     chart.cartesian_grid(stroke_dasharray='3 3')
     chart.legend()

     # getattr dispatches to the series builder named by `type`
     # (e.g. chart.line or chart.area).
     getattr(chart, type)(type='monotone',
                          data_key='loss',
                          stroke='rgb(44,125,246)',
                          fill='rgb(44,125,246)',
                          dot="false",
                          y_axis_id='loss_axis')
     getattr(chart, type)(type='monotone',
                          data_key='acc',
                          stroke='#82ca9d',
                          fill='#82ca9d',
                          dot="false",
                          y_axis_id='acc_axis')

     # HACK: Use get_report_ctx() to grab root delta generator in an i9e
     # world.
     # TODO: Make this file not need _native_chart
     return get_report_ctx().main_dg._native_chart(chart)
Example no. 5
0
 def _create_chart(self, type="line", height=0):
     """Build a two-series (loss/accuracy) chart of the given type."""
     # Start from an empty frame; rows are presumably streamed in later.
     frame = pd.DataFrame(columns=["loss", "accuracy"])
     chart = Chart(frame, "%s_chart" % type, height=height)

     # Left-hand axis for loss.
     chart.y_axis(
         type="number", y_axis_id="loss_axis", allow_data_overflow="true"
     )
     # Right-hand axis for accuracy.
     chart.y_axis(
         type="number",
         orientation="right",
         y_axis_id="acc_axis",
         allow_data_overflow="true",
     )
     chart.cartesian_grid(stroke_dasharray="3 3")
     chart.legend()

     # One series per metric; getattr picks the builder named by `type`
     # (e.g. chart.line or chart.area).
     for data_key, color, axis_id in (
         ("loss", "rgb(44,125,246)", "loss_axis"),
         ("accuracy", "#82ca9d", "acc_axis"),
     ):
         getattr(chart, type)(
             type="monotone",
             data_key=data_key,
             stroke=color,
             fill=color,
             dot="false",
             y_axis_id=axis_id,
         )

     # HACK: Use get_report_ctx() to grab root delta generator in an i9e
     # world.
     # TODO: Make this file not need _native_chart
     return get_report_ctx().main_dg._native_chart(chart)
Example no. 6
0
def _get_session():
    """Return the ReportSession associated with the current report context.

    Raises
    ------
    RuntimeError
        If the server has no session registered under this context's ID.
    """
    current_id = get_report_ctx().session_id
    info = Server.get_current()._get_session_info(current_id)

    if info is None:
        raise RuntimeError("Couldn't get your Streamlit Session object.")

    return info.session
Example no. 7
0
def _get_session_id():
    """Semantic wrapper to retrieve current ReportSession ID."""
    ctx = get_report_ctx()

    if ctx is not None:
        return ctx.session_id

    # The context is only None when running "python myscript.py" rather
    # than "streamlit run myscript.py". In that case the session ID doesn't
    # matter and can just be a constant, as there's only ever one "session".
    return "dontcare"
Example no. 8
0
    def test_add_rows_works_when_new_name(self):
        """Test add_rows with new named datasets."""
        for data_method in self._get_named_data_methods():
            # Build a fresh data-carrying element (e.g. st.dataframe).
            element = data_method(DATAFRAME)
            self.report_queue.clear()

            # This is what we're testing: a dataset name the element
            # has not seen before.
            element.add_rows(new_name=NEW_ROWS)

            # The appended delta should carry all 3 new rows.
            add_rows_proto = self.get_delta_from_queue().add_rows
            self.assertEqual(
                len(add_rows_proto.data.data.cols[0].int64s.data), 3)

            # Reset global state so the next iteration starts clean.
            get_report_ctx().reset()
            self.report_queue.clear()
Example no. 9
0
def _get_full_session():
    """Return the full SessionInfo for the current report context.

    MODIFIED ORIGINAL _get_session CODE SO WE CAN ACCESS HEADERS FOR USER:
    unlike _get_session, this returns the whole SessionInfo object rather
    than session_info.session.
    """
    info = Server.get_current()._get_session_info(
        get_report_ctx().session_id)

    if info is None:
        raise RuntimeError("Couldn't get your Streamlit Session object.")

    return info
Example no. 10
0
def get_container_cursor(container):
    """Return the RunningCursor for *container*, creating it on first use.

    Returns None when there is no report context (e.g. outside a report run).
    """
    ctx = get_report_ctx()

    if ctx is None:
        return None

    # EAFP: look the cursor up, lazily creating it on the first miss.
    try:
        return ctx.cursors[container]
    except KeyError:
        new_cursor = RunningCursor()
        ctx.cursors[container] = new_cursor
        return new_cursor
    def test_named_add_rows(self):
        """Test add_rows with a named dataset."""
        for data_method in self._get_named_data_methods():
            # Build a fresh data-carrying element (e.g. st.dataframe).
            element = data_method(DATAFRAME)

            # Verify the element starts with 2 rows.
            proto = data_frame_proto._get_data_frame(
                self.get_delta_from_queue())
            self.assertEqual(len(proto.data.cols[0].int64s.data), 2)

            # This is what we're testing: append to the named dataset.
            element.add_rows(mydata1=NEW_ROWS)

            # After appending, the element should hold 5 rows total.
            proto = data_frame_proto._get_data_frame(
                self.get_delta_from_queue())
            self.assertEqual(len(proto.data.cols[0].int64s.data), 5)

            # Reset global state so the next iteration starts clean.
            get_report_ctx().reset()
            self.report_queue.clear()
Example no. 12
0
    def setUp(self, override_root=True):
        """Create a fresh ReportQueue and optionally install a fake context.

        When override_root is True, the current thread's report context is
        replaced with one that enqueues into self.report_queue; the previous
        context is stashed in self.orig_report_ctx.
        """
        self.report_queue = ReportQueue()
        self.override_root = override_root
        self.orig_report_ctx = None

        if not override_root:
            return

        # Remember the real context before replacing it.
        self.orig_report_ctx = get_report_ctx()

        fake_ctx = ReportContext(
            enqueue=self.report_queue.enqueue,
            widgets=Widgets(),
            widget_ids_this_run=_WidgetIDSet(),
            uploaded_file_mgr=UploadedFileManager(),
        )
        add_report_ctx(threading.current_thread(), fake_ctx)
Example no. 13
0
    def test_handle_save_request(self, _1):
        """Test that handle_save_request serializes files correctly."""
        # Create a ReportSession with some mocked bits
        rs = ReportSession(False, self.io_loop, "mock_report.py", "",
                           UploadedFileManager())
        rs._report.report_id = "TestReportID"

        # Install a report context whose enqueue feeds this session's
        # report, so the st.* calls below land in rs._report's queue.
        # The original context is saved and restored at the end.
        orig_ctx = get_report_ctx()
        ctx = ReportContext("TestSessionID", rs._report.enqueue, None, None,
                            None)
        add_report_ctx(ctx=ctx)

        rs._scriptrunner = MagicMock()

        # In-memory storage stub that records each saved file in order.
        storage = MockStorage()
        rs._storage = storage

        # Send two deltas: empty and markdown
        st.empty()
        st.markdown("Text!")

        # Tornado-style coroutine test: yield waits for the save to finish.
        yield rs.handle_save_request(_create_mock_websocket())

        # Check the order of the received files. Manifest should be last.
        self.assertEqual(3, len(storage.files))
        self.assertEqual("reports/TestReportID/0.pb", storage.get_filename(0))
        self.assertEqual("reports/TestReportID/1.pb", storage.get_filename(1))
        self.assertEqual("reports/TestReportID/manifest.pb",
                         storage.get_filename(2))

        # Check the manifest
        manifest = storage.get_message(2, StaticManifest)
        self.assertEqual("mock_report", manifest.name)
        self.assertEqual(2, manifest.num_messages)
        self.assertEqual(StaticManifest.DONE, manifest.server_status)

        # Check that the deltas we sent match messages in storage
        sent_messages = rs._report._master_queue._queue
        received_messages = [
            storage.get_message(0, ForwardMsg),
            storage.get_message(1, ForwardMsg),
        ]

        self.assertEqual(sent_messages, received_messages)

        # Restore the original context so later tests are unaffected.
        add_report_ctx(ctx=orig_ctx)
Example no. 14
0
 def wrapped_method(*args, **kwargs):
     """Invoke `method` on the current report's main DeltaGenerator.

     Falls back to _NULL_DELTA_GENERATOR when no report context exists.
     """
     ctx = get_report_ctx()
     if ctx is None:
         return method(_NULL_DELTA_GENERATOR, *args, **kwargs)
     return method(ctx.main_dg, *args, **kwargs)
Example no. 15
0
    def _run_script(self, rerun_data):
        """Run our script.

        Compiles and executes the report script in a fake __main__ module,
        emitting SCRIPT_STARTED / SCRIPT_STOPPED_* events around the run.
        If execution is interrupted by a RerunException, the script is
        re-run with the exception's RerunData.

        Parameters
        ----------
        rerun_data: RerunData
            The RerunData to use.

        """
        assert self._is_in_script_thread()

        LOGGER.debug("Running script %s", rerun_data)

        # Reset DeltaGenerators, widgets, media files.
        media_file_manager.clear_session_files()
        get_report_ctx().reset()

        self.on_event.send(ScriptRunnerEvent.SCRIPT_STARTED)

        # Compile the script. Any errors thrown here will be surfaced
        # to the user via a modal dialog in the frontend, and won't result
        # in their previous report disappearing.

        try:
            with source_util.open_python_file(self._report.script_path) as f:
                filebody = f.read()

            if config.get_option("runner.magicEnabled"):
                # Rewrite the source so bare expressions are displayed
                # ("magic" mode).
                filebody = magic.add_magic(filebody, self._report.script_path)

            code = compile(
                filebody,
                # Pass in the file path so it can show up in exceptions.
                self._report.script_path,
                # We're compiling entire blocks of Python, so we need "exec"
                # mode (as opposed to "eval" or "single").
                mode="exec",
                # Don't inherit any flags or "future" statements.
                flags=0,
                dont_inherit=1,
                # Use the default optimization options.
                optimize=-1,
            )

        except BaseException as e:
            # We got a compile error. Send an error event and bail immediately.
            LOGGER.debug("Fatal script error: %s" % e)
            self.on_event.send(
                ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR,
                exception=e)
            return

        # If we get here, we've successfully compiled our script. The next step
        # is to run it. Errors thrown during execution will be shown to the
        # user as ExceptionElements.

        # Update the Widget singleton with the new widget_state
        if rerun_data.widget_state is not None:
            self._widgets.set_state(rerun_data.widget_state)

        if config.get_option("runner.installTracer"):
            self._install_tracer()

        # This will be set to a RerunData instance if our execution
        # is interrupted by a RerunException.
        rerun_with_data = None

        try:
            # Create fake module. This gives us a name global namespace to
            # execute the code in.
            module = _new_module("__main__")

            # Install the fake module as the __main__ module. This allows
            # the pickle module to work inside the user's code, since it now
            # can know the module where the pickled objects stem from.
            # IMPORTANT: This means we can't use "if __name__ == '__main__'" in
            # our code, as it will point to the wrong module!!!
            sys.modules["__main__"] = module

            # Add special variables to the module's globals dict.
            # Note: The following is a requirement for the CodeHasher to
            # work correctly. The CodeHasher is scoped to
            # files contained in the directory of __main__.__file__, which we
            # assume is the main script directory.
            module.__dict__["__file__"] = self._report.script_path

            with modified_sys_path(self._report), self._set_execing_flag():
                exec(code, module.__dict__)

        except RerunException as e:
            # Not an error: the script asked to be re-run; remember the data
            # and fall through to the finally block first.
            rerun_with_data = e.rerun_data

        except StopException:
            # Not an error: the script asked to stop early.
            pass

        except BaseException as e:
            # Show exceptions in the Streamlit report.
            LOGGER.debug(e)
            import streamlit as st

            st.exception(e)  # This is OK because we're in the script thread.
            # TODO: Clean up the stack trace, so it doesn't include
            # ScriptRunner.

        finally:
            # NOTE: this runs (and sends SCRIPT_STOPPED_WITH_SUCCESS) on
            # every path, including the exception branches above.
            self._widgets.reset_triggers()
            self.on_event.send(ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS)

        # Use _log_if_error() to make sure we never ever ever stop running the
        # script without meaning to.
        _log_if_error(_clean_problem_modules)

        if rerun_with_data is not None:
            self._run_script(rerun_with_data)
Example no. 16
0
def _reset(main_dg, sidebar_dg):
    """Reset both root DeltaGenerators and rebind the module-level sidebar."""
    global sidebar

    for dg in (main_dg, sidebar_dg):
        dg._reset()

    # Re-export the (reset) sidebar DeltaGenerator at module level.
    sidebar = sidebar_dg

    get_report_ctx().widget_ids_this_run.clear()