예제 #1
0
    def setUp(self, override_root=True):
        """Create a fresh ReportQueue and optionally install a mock context.

        Parameters
        ----------
        override_root : bool
            When true, a MockReportContext wrapping a new DeltaGenerator is
            stored on the current thread under REPORT_CONTEXT_ATTR_NAME.

        """
        self.report_queue = ReportQueue()

        if not override_root:
            return

        mock_ctx = MockReportContext(self.new_delta_generator())
        setattr(threading.current_thread(), REPORT_CONTEXT_ATTR_NAME, mock_ctx)
    def __init__(self, script_name):
        """Build a ScriptRunner for a script in the ``test_data`` directory.

        Parameters
        ----------
        script_name : str
            Name of the script file, joined onto the ``test_data`` directory
            that sits next to this module.

        """
        # Messages forwarded by the runner are collected in this queue.
        self.report_queue = ReportQueue()

        def forward_to_queue(msg):
            # Store the message, then call
            # maybe_handle_execution_control_request() after every enqueue.
            self.report_queue.enqueue(msg)
            self.maybe_handle_execution_control_request()

        self.script_request_queue = ScriptRequestQueue()

        here = os.path.dirname(__file__)
        path_to_script = os.path.join(here, "test_data", script_name)
        report = Report(path_to_script, "test command line")

        super(TestScriptRunner, self).__init__(
            session_id="test session id",
            report=report,
            enqueue_forward_msg=forward_to_queue,
            widget_states=WidgetStates(),
            request_queue=self.script_request_queue,
        )

        # Uncaught exceptions raised on the run thread are collected here.
        self.script_thread_exceptions = []

        # Every ScriptRunnerEvent this runner emits, in emission order.
        self.events = []

        def remember_event(event, **kwargs):
            # Extra keyword arguments sent with the event are ignored.
            self.events.append(event)

        # weak=False is passed so the signal is asked not to hold this local
        # closure through a weak reference.
        self.on_event.connect(remember_event, weak=False)
예제 #3
0
    def __init__(self, script_name):
        """Build a ScriptRunner for a script in the ``test_data`` directory.

        Parameters
        ----------
        script_name : str
            Name of the script file, joined onto the ``test_data`` directory
            that sits next to this module.

        """
        # Both DeltaGenerators below enqueue into this queue.
        self.report_queue = ReportQueue()

        def forward_to_queue(msg):
            # Store the message, then call
            # maybe_handle_execution_control_request() after every enqueue.
            self.report_queue.enqueue(msg)
            self.maybe_handle_execution_control_request()
            return True

        self.main_dg = DeltaGenerator(
            forward_to_queue, container=BlockPath.MAIN)
        self.sidebar_dg = DeltaGenerator(
            forward_to_queue, container=BlockPath.SIDEBAR)
        self.script_request_queue = ScriptRequestQueue()

        here = os.path.dirname(__file__)
        path_to_script = os.path.join(here, 'test_data', script_name)
        report = Report(path_to_script, [])

        super(TestScriptRunner, self).__init__(
            report=report,
            main_dg=self.main_dg,
            sidebar_dg=self.sidebar_dg,
            widget_states=WidgetStates(),
            request_queue=self.script_request_queue,
        )

        # Uncaught exceptions raised on the run thread are collected here.
        self.script_thread_exceptions = []

        # Every ScriptRunnerEvent this runner emits, in emission order.
        self.events = []

        def remember_event(event, **kwargs):
            # Extra keyword arguments sent with the event are ignored.
            self.events.append(event)

        self.on_event.connect(remember_event, weak=False)
예제 #4
0
class DeltaGeneratorTestCase(unittest.TestCase):
    """TestCase base that captures DeltaGenerator output in a ReportQueue."""

    def setUp(self, override_root=True):
        """Create the queue and optionally install a mock report context.

        Parameters
        ----------
        override_root : bool
            When true, a MockReportContext wrapping a new DeltaGenerator is
            stored on the current thread under REPORT_CONTEXT_ATTR_NAME.

        """
        self.report_queue = ReportQueue()

        if not override_root:
            return

        mock_ctx = MockReportContext(self.new_delta_generator())
        setattr(threading.current_thread(), REPORT_CONTEXT_ATTR_NAME, mock_ctx)

    def tearDown(self):
        """Empty the queue and remove the context attribute if it is set."""
        self.report_queue._clear()

        thread = threading.current_thread()
        if hasattr(thread, REPORT_CONTEXT_ATTR_NAME):
            delattr(thread, REPORT_CONTEXT_ATTR_NAME)

    def new_delta_generator(self, *args, **kwargs):
        """Create a DeltaGenerator.

        The enqueue callable is taken from the first positional argument if
        there is one, otherwise from the ``enqueue`` keyword argument, and
        otherwise defaults to a function that appends to
        ``self.report_queue``. All remaining arguments are passed through to
        DeltaGenerator.

        """
        def enqueue_to_report_queue(msg):
            self.report_queue.enqueue(msg)
            return True

        if args:
            enqueue, args = args[0], args[1:]
        else:
            enqueue = kwargs.pop('enqueue', enqueue_to_report_queue)

        return DeltaGenerator(enqueue, *args, **kwargs)

    def get_message_from_queue(self, index=-1):
        """Return the queued message at ``index`` (the last one by default)."""
        queued = self.report_queue._queue
        return queued[index]

    def get_delta_from_queue(self, index=-1):
        """Return the ``delta`` attribute of the queued message at ``index``."""
        message = self.report_queue._queue[index]
        return message.delta
예제 #5
0
    def __init__(self, script_path, command_line):
        """Constructor.

        Parameters
        ----------
        script_path : str
            Path of the Python file from which this app is generated.

        command_line : string
            Command line as input by the user

        """
        absolute_path = os.path.abspath(script_path)
        stem, _extension = os.path.splitext(os.path.basename(script_path))

        self.script_path = absolute_path
        self.script_folder = os.path.dirname(absolute_path)
        # The report name is the script's file name without its extension.
        self.name = stem

        # Master queue: every message that makes up the report. Its contents
        # are what gets serialized when the user shares a saved report.
        self._master_queue = ReportQueue()

        # Browser queue: messages not yet delivered to the browser. The
        # server periodically flushes it and sends the contents on.
        self._browser_queue = ReportQueue()

        # generate_new_id() is called right away to replace this placeholder.
        self.report_id = None
        self.generate_new_id()

        self.command_line = command_line
예제 #6
0
    def setUp(self):
        """Create a DeltaGenerator whose messages land in a new ReportQueue."""
        self._report_queue = ReportQueue()

        def collect(message):
            # Reads self._report_queue at call time; always returns True.
            self._report_queue.enqueue(message)
            return True

        self._dg = DeltaGenerator(collect)
예제 #7
0
class DeltaGeneratorTestCase(unittest.TestCase):
    """TestCase base that captures DeltaGenerator output in a ReportQueue."""

    def setUp(self, override_root=True):
        """Create the queue and optionally install a ReportContext.

        Parameters
        ----------
        override_root : bool
            When true, a ReportContext holding new main and sidebar
            DeltaGenerators is stored on the current thread under
            REPORT_CONTEXT_ATTR_NAME.

        """
        self.report_queue = ReportQueue()

        if not override_root:
            return

        report_ctx = ReportContext(
            main_dg=self.new_delta_generator(),
            sidebar_dg=self.new_delta_generator(container=BlockPath.SIDEBAR),
            widgets=Widgets(),
            widget_ids_this_run=_WidgetIDSet(),
            uploaded_file_mgr=None,
        )
        setattr(threading.current_thread(), REPORT_CONTEXT_ATTR_NAME,
                report_ctx)

    def tearDown(self):
        """Empty the queue and remove the context attribute if it is set."""
        self.report_queue._clear()

        thread = threading.current_thread()
        if hasattr(thread, REPORT_CONTEXT_ATTR_NAME):
            delattr(thread, REPORT_CONTEXT_ATTR_NAME)

    def new_delta_generator(self, *args, **kwargs):
        """Create a DeltaGenerator.

        The enqueue callable is taken from the first positional argument if
        there is one, otherwise from the ``enqueue`` keyword argument, and
        otherwise defaults to a function that appends to
        ``self.report_queue``. All remaining arguments are passed through to
        DeltaGenerator.

        """
        def enqueue_to_report_queue(msg):
            self.report_queue.enqueue(msg)
            return True

        if args:
            enqueue, args = args[0], args[1:]
        else:
            enqueue = kwargs.pop("enqueue", enqueue_to_report_queue)

        return DeltaGenerator(enqueue, *args, **kwargs)

    def get_message_from_queue(self, index=-1):
        """Return the queued ForwardMsg at ``index`` (last one by default).

        Returns
        -------
        ForwardMsg
        """
        queued = self.report_queue._queue
        return queued[index]

    def get_delta_from_queue(self, index=-1):
        """Return the Delta of the queued message at ``index``.

        Returns
        -------
        Delta
        """
        message = self.report_queue._queue[index]
        return message.delta
예제 #8
0
    def setUp(self, override_root=True):
        """Create a fresh ReportQueue and optionally install a ReportContext.

        Parameters
        ----------
        override_root : bool
            When true, a ReportContext holding two new DeltaGenerators is
            stored on the current thread under REPORT_CONTEXT_ATTR_NAME.

        """
        self.report_queue = ReportQueue()

        if not override_root:
            return

        # NOTE(review): the sidebar generator is created with the same
        # default arguments as the main one; confirm that is intended.
        report_ctx = ReportContext(
            main_dg=self.new_delta_generator(),
            sidebar_dg=self.new_delta_generator(),
            widgets=Widgets(),
        )
        setattr(threading.current_thread(), REPORT_CONTEXT_ATTR_NAME,
                report_ctx)
예제 #9
0
class DeltaGeneratorTestCase(unittest.TestCase):
    """TestCase base that routes report-context output into a ReportQueue."""

    def setUp(self, override_root=True):
        """Create the queue and optionally swap in a test ReportContext.

        Parameters
        ----------
        override_root : bool
            When true, the context returned by get_report_ctx() is saved in
            ``self.orig_report_ctx`` and a new ReportContext that enqueues
            into ``self.report_queue`` is registered for the current thread.

        """
        self.report_queue = ReportQueue()
        self.override_root = override_root
        self.orig_report_ctx = None

        if not self.override_root:
            return

        # Saved so tearDown can register it again.
        self.orig_report_ctx = get_report_ctx()

        test_ctx = ReportContext(
            session_id="test session id",
            enqueue=self.report_queue.enqueue,
            widgets=Widgets(),
            widget_ids_this_run=_WidgetIDSet(),
            uploaded_file_mgr=UploadedFileManager(),
        )
        add_report_ctx(threading.current_thread(), test_ctx)

    def tearDown(self):
        """Empty the queue and re-register the context saved by setUp."""
        self.clear_queue()
        if not self.override_root:
            return
        add_report_ctx(threading.current_thread(), self.orig_report_ctx)

    def get_message_from_queue(self, index=-1):
        """Return the queued ForwardMsg at ``index`` (last one by default).

        Returns
        -------
        ForwardMsg
        """
        queued = self.report_queue._queue
        return queued[index]

    def get_delta_from_queue(self, index=-1):
        """Return the Delta at ``index`` among the queued delta messages.

        The index counts only messages that have a ``delta`` field.

        Returns
        -------
        Delta
        """
        return self.get_all_deltas_from_queue()[index]

    def get_all_deltas_from_queue(self):
        """Return the delta of every queued message that has one, in order."""
        deltas = []
        for message in self.report_queue._queue:
            if message.HasField("delta"):
                deltas.append(message.delta)
        return deltas

    def clear_queue(self):
        """Remove every message from the report queue."""
        self.report_queue._clear()
예제 #10
0
    def setUp(self, override_root=True):
        """Create a fresh ReportQueue and optionally swap in a test context.

        Parameters
        ----------
        override_root : bool
            When true, the context returned by get_report_ctx() is saved in
            ``self.orig_report_ctx`` and a new ReportContext that enqueues
            into ``self.report_queue`` is registered for the current thread.

        """
        self.report_queue = ReportQueue()
        self.override_root = override_root
        self.orig_report_ctx = None

        if not self.override_root:
            return

        self.orig_report_ctx = get_report_ctx()

        test_ctx = ReportContext(
            enqueue=self.report_queue.enqueue,
            widgets=Widgets(),
            widget_ids_this_run=_WidgetIDSet(),
            uploaded_file_mgr=UploadedFileManager(),
        )
        add_report_ctx(threading.current_thread(), test_ctx)
예제 #11
0
    def setUp(self, override_root=True):
        """Create a fresh ReportQueue and optionally install a ReportContext.

        Parameters
        ----------
        override_root : bool
            When true, a ReportContext holding new main and sidebar
            DeltaGenerators is stored on the current thread under
            REPORT_CONTEXT_ATTR_NAME.

        """
        self.report_queue = ReportQueue()

        if not override_root:
            return

        report_ctx = ReportContext(
            main_dg=self.new_delta_generator(),
            sidebar_dg=self.new_delta_generator(container=BlockPath.SIDEBAR),
            widgets=Widgets(),
            widget_ids_this_run=_WidgetIDSet(),
        )
        setattr(threading.current_thread(), REPORT_CONTEXT_ATTR_NAME,
                report_ctx)
예제 #12
0
    def test_replace_element(self):
        """Two deltas enqueued with the same id flush as a single message."""
        report_queue = ReportQueue()
        self.assertTrue(report_queue.is_empty())

        report_queue.enqueue(INIT_MSG)

        # Both text messages are given delta id 0 before being enqueued.
        for text_msg in (TEXT_DELTA_MSG1, TEXT_DELTA_MSG2):
            text_msg.delta.id = 0
            report_queue.enqueue(text_msg)

        flushed = report_queue.flush()
        self.assertEqual(len(flushed), 2)
        self.assertTrue(flushed[0].initialize.config.sharing_enabled)

        # Only the second text message is expected to remain for id 0.
        surviving = flushed[1]
        self.assertEqual(surviving.delta.id, 0)
        self.assertEqual(surviving.delta.new_element.text.body, 'text2')
예제 #13
0
    def test_simple_enqueue(self):
        """A single enqueued message is returned by flush, emptying the queue."""
        report_queue = ReportQueue()
        self.assertTrue(report_queue.is_empty())

        report_queue.enqueue(INIT_MSG)
        self.assertFalse(report_queue.is_empty())

        flushed = report_queue.flush()

        # Flushing is expected to leave the queue empty again.
        self.assertTrue(report_queue.is_empty())
        self.assertEqual(len(flushed), 1)
        self.assertTrue(flushed[0].initialize.config.sharing_enabled)
예제 #14
0
    def test_enqueue_three(self):
        """Messages with distinct delta ids are all kept, in enqueue order."""
        report_queue = ReportQueue()
        self.assertTrue(report_queue.is_empty())

        report_queue.enqueue(INIT_MSG)

        # The two text messages get delta ids 0 and 1 respectively.
        for delta_id, text_msg in enumerate((TEXT_DELTA_MSG1, TEXT_DELTA_MSG2)):
            text_msg.metadata.delta_id = delta_id
            report_queue.enqueue(text_msg)

        flushed = report_queue.flush()
        self.assertEqual(len(flushed), 3)
        self.assertTrue(flushed[0].initialize.config.sharing_enabled)

        expected = ((0, "text1"), (1, "text2"))
        for position, (delta_id, body) in enumerate(expected, 1):
            message = flushed[position]
            self.assertEqual(message.metadata.delta_id, delta_id)
            self.assertEqual(message.delta.new_element.text.body, body)
예제 #15
0
    def test_multiple_containers(self):
        """Deltas should only be coalesced if they're in the same container"""
        report_queue = ReportQueue()
        self.assertTrue(report_queue.is_empty())

        report_queue.enqueue(INIT_MSG)

        def enqueue_deltas(container, path):
            # Each shared proto is deep-copied before it is mutated, because
            # this helper runs once per container.
            templates = (
                (TEXT_DELTA_MSG1, 0),
                (DF_DELTA_MSG, 1),
                (ADD_ROWS_MSG, 1),
            )
            for template, delta_id in templates:
                msg = copy.deepcopy(template)
                msg.delta.id = delta_id
                msg.delta.parent_block.container = container
                msg.delta.parent_block.path[:] = path
                report_queue.enqueue(msg)

        enqueue_deltas(BlockPath.MAIN, [])
        enqueue_deltas(BlockPath.SIDEBAR, [0, 0, 1])

        flushed = report_queue.flush()

        def assert_deltas(container, path, idx):
            # flushed[idx] is expected to be the text delta and
            # flushed[idx + 1] the data frame delta for this container.
            text_delta = flushed[idx].delta
            self.assertEqual(0, text_delta.id)
            self.assertEqual(container, text_delta.parent_block.container)
            self.assertEqual(path, text_delta.parent_block.path)
            self.assertEqual('text1', text_delta.new_element.text.body)

            frame_delta = flushed[idx + 1].delta
            self.assertEqual(1, frame_delta.id)
            self.assertEqual(container, frame_delta.parent_block.container)
            self.assertEqual(path, frame_delta.parent_block.path)
            columns = frame_delta.new_element.data_frame.data.cols
            first_column = columns[0].int64s.data
            second_column = columns[1].int64s.data
            self.assertEqual([0, 1, 2, 3, 4, 5], first_column)
            self.assertEqual([10, 11, 12, 13, 14, 15], second_column)

        self.assertEqual(5, len(flushed))
        self.assertTrue(flushed[0].initialize.config.sharing_enabled)

        assert_deltas(BlockPath.MAIN, [], 1)
        assert_deltas(BlockPath.SIDEBAR, [0, 0, 1], 3)
예제 #16
0
class TestScriptRunner(ScriptRunner):
    """ScriptRunner subclass that records its output for use in tests."""

    def __init__(self, script_name):
        """Build a ScriptRunner for a script in the ``test_data`` directory.

        Parameters
        ----------
        script_name : str
            Name of the script file, joined onto the ``test_data`` directory
            that sits next to this module.

        """
        # Both DeltaGenerators below enqueue into this queue.
        self.report_queue = ReportQueue()

        def forward_to_queue(msg):
            # Store the message, then call
            # maybe_handle_execution_control_request() after every enqueue.
            self.report_queue.enqueue(msg)
            self.maybe_handle_execution_control_request()
            return True

        self.main_dg = DeltaGenerator(
            enqueue=forward_to_queue, container=BlockPath.MAIN
        )
        self.sidebar_dg = DeltaGenerator(
            enqueue=forward_to_queue, container=BlockPath.SIDEBAR
        )
        self.script_request_queue = ScriptRequestQueue()

        here = os.path.dirname(__file__)
        path_to_script = os.path.join(here, "test_data", script_name)
        report = Report(path_to_script, "test command line")

        super(TestScriptRunner, self).__init__(
            report=report,
            main_dg=self.main_dg,
            sidebar_dg=self.sidebar_dg,
            widget_states=WidgetStates(),
            request_queue=self.script_request_queue,
        )

        # Uncaught exceptions raised on the run thread are collected here.
        self.script_thread_exceptions = []

        # Every ScriptRunnerEvent this runner emits, in emission order.
        self.events = []

        def remember_event(event, **kwargs):
            # Extra keyword arguments sent with the event are ignored.
            self.events.append(event)

        self.on_event.connect(remember_event, weak=False)

    @property
    def widget_states(self):
        """Return the current state held by ``self._widgets``.

        Returns
        -------
        WidgetStates
            A WidgetStates protobuf object

        """
        widgets = self._widgets
        return widgets.get_state()

    def enqueue_rerun(self, argv=None, widget_state=None):
        """Enqueue a RERUN request carrying ``widget_state``.

        ``argv`` is accepted but not used.
        """
        rerun_data = RerunData(widget_state=widget_state)
        self.script_request_queue.enqueue(ScriptRequest.RERUN, rerun_data)

    def enqueue_stop(self):
        """Enqueue a STOP request."""
        self.script_request_queue.enqueue(ScriptRequest.STOP)

    def enqueue_shutdown(self):
        """Enqueue a SHUTDOWN request."""
        self.script_request_queue.enqueue(ScriptRequest.SHUTDOWN)

    def _process_request_queue(self):
        """Run the parent implementation, recording anything it raises.

        Every exception, including BaseException subclasses, is appended to
        ``self.script_thread_exceptions`` instead of propagating.
        """
        try:
            super(TestScriptRunner, self)._process_request_queue()
        except BaseException as exc:
            self.script_thread_exceptions.append(exc)

    def join(self):
        """Wait for the run thread to finish, if one exists."""
        if self._script_thread is None:
            return
        self._script_thread.join()

    def deltas(self):
        """Return the delta of every queued message that has one, in order."""
        found = []
        for message in self.report_queue._queue:
            if message.HasField("delta"):
                found.append(message.delta)
        return found

    def get_widget_id(self, widget_type, label):
        """Return the id of the first widget matching ``widget_type`` and
        ``label``, or None when no queued delta matches."""
        for delta in self.deltas():
            element = getattr(delta, "new_element", None)
            candidate = getattr(element, widget_type, None)
            if getattr(candidate, "label", None) == label:
                return candidate.id
        return None
예제 #17
0
    def test_add_rows_rerun(self):
        """Enqueuing the same three deltas twice still flushes two deltas."""
        report_queue = ReportQueue()
        self.assertTrue(report_queue.is_empty())

        report_queue.enqueue(INIT_MSG)

        # Simulate a rerun by enqueuing the identical sequence twice.
        sequence = (
            (TEXT_DELTA_MSG1, 0),
            (DF_DELTA_MSG, 1),
            (ADD_ROWS_MSG, 1),
        )
        for _ in range(2):
            for message, delta_id in sequence:
                message.delta.id = delta_id
                report_queue.enqueue(message)

        flushed = report_queue.flush()
        self.assertEqual(len(flushed), 3)
        self.assertTrue(flushed[0].initialize.config.sharing_enabled)

        text_delta = flushed[1].delta
        self.assertEqual(text_delta.id, 0)
        self.assertEqual(text_delta.new_element.text.body, 'text1')

        frame_delta = flushed[2].delta
        self.assertEqual(frame_delta.id, 1)
        columns = frame_delta.new_element.data_frame.data.cols
        first_column = columns[0].int64s.data
        second_column = columns[1].int64s.data
        self.assertEqual(first_column, [0, 1, 2, 3, 4, 5])
        self.assertEqual(second_column, [10, 11, 12, 13, 14, 15])
예제 #18
0
class Report(object):
    """
    Contains parameters related to running a report, and also houses
    the two ReportQueues (master_queue and browser_queue) that are used
    to deliver messages to a connected browser, and to serialize the
    running report.
    """
    @classmethod
    def get_url(cls, host_ip):
        """Get the URL for any app served at the given host_ip.

        Parameters
        ----------
        host_ip : str
            The IP address of the machine that is running the Streamlit Server.

        Returns
        -------
        str
            The URL.
        """
        port = _get_browser_address_bar_port()
        return "http://%(host_ip)s:%(port)s" % {
            "host_ip": host_ip,
            "port": port
        }

    def __init__(self, script_path, argv):
        """Constructor.

        Parameters
        ----------
        script_path : str
            Path of the Python file from which this app is generated.

        argv : list of str
            Command-line arguments to run the script with.

        """
        basename = os.path.basename(script_path)

        self.script_path = os.path.abspath(script_path)
        self.script_folder = os.path.dirname(self.script_path)
        self.argv = argv
        # The report name is the script's file name without its extension.
        self.name = os.path.splitext(basename)[0]

        # The master queue contains all messages that comprise the report.
        # If the user chooses to share a saved version of the report,
        # we serialize the contents of the master queue.
        self._master_queue = ReportQueue()

        # The browser queue contains messages that haven't yet been
        # delivered to the browser. Periodically, the server flushes
        # this queue and delivers its contents to the browser.
        self._browser_queue = ReportQueue()

        # generate_new_id() is called right away to replace this placeholder.
        self.report_id = None
        self.generate_new_id()

    def get_debug(self):
        """Return a dict holding the master queue's debug information."""
        return {"master queue": self._master_queue.get_debug()}

    def parse_argv_from_command_line(self, cmd_line_str):
        """Parse a command line string and store the result in ``self.argv``.

        The string is split with ``shlex.split``. Its first token must
        resolve to this report's ``script_path``.

        Parameters
        ----------
        cmd_line_str : str
            The string to parse.

        Returns
        -------
        None
            The parsed argument list is stored in ``self.argv``.

        Raises
        ------
        ValueError
            If the command line contains no tokens, or if its first token
            refers to a different script than ``self.script_path``.

        """
        import shlex

        cmd_line_list = shlex.split(cmd_line_str)

        # Without this guard, an empty or whitespace-only command line would
        # fail with an IndexError on the lookup below.
        if not cmd_line_list:
            raise ValueError("Cannot parse an empty command line")

        new_script_path = os.path.abspath(cmd_line_list[0])

        if new_script_path != self.script_path:
            raise ValueError("Cannot change script from %s to %s" %
                             (self.script_path, cmd_line_list[0]))

        self.argv = cmd_line_list

    def enqueue(self, msg):
        """Add ``msg`` to both the master queue and the browser queue."""
        self._master_queue.enqueue(msg)
        self._browser_queue.enqueue(msg)

    def clear(self):
        """Clear both queues, keeping the master queue's initial message."""
        # Master_queue retains its initial message; browser_queue is
        # completely cleared.
        initial_msg = self._master_queue.get_initial_msg()
        self._master_queue.clear()
        if initial_msg:
            self._master_queue.enqueue(initial_msg)

        self._browser_queue.clear()

    def flush_browser_queue(self):
        """Clears our browser queue and returns the messages it contained.

        The Server calls this periodically to deliver new messages
        to the browser connected to this report.

        This doesn't affect the master_queue.

        Returns
        -------
        list[ForwardMsg]
            The messages that were removed from the queue and should
            be delivered to the browser.

        """
        return self._browser_queue.flush()

    def generate_new_id(self):
        """Randomly generate an ID representing this report's execution."""
        # Convert to str for Python2
        self.report_id = str(
            base58.b58encode(uuid.uuid4().bytes).decode("utf-8"))

    def serialize_running_report_to_files(self):
        """Return a running report as an easily-serializable list of tuples.

        Returns
        -------
        list of tuples
            See `CloudStorage.save_report_files()` for schema. But as to the
            output of this method, it's just a manifest pointing to the Server
            so browsers who go to the shareable report URL can connect to it
            live.

        """
        LOGGER.debug("Serializing running report")

        manifest = self._build_manifest(
            status="running",
            external_server_ip=util.get_external_ip(),
            internal_server_ip=util.get_internal_ip(),
        )

        manifest_json = json.dumps(manifest).encode("utf-8")

        return [("reports/%s/manifest.json" % self.report_id, manifest_json)]

    def serialize_final_report_to_files(self):
        """Return the report as an easily-serializable list of tuples.

        Returns
        -------
        list of tuples
            See `CloudStorage.save_report_files()` for schema. But as to the
            output of this method, it's (1) a simple manifest and (2) a bunch
            of serialized ForwardMsgs.

        """
        LOGGER.debug("Serializing final report")

        # Deep copies, so renumbering delta_id below does not modify the
        # messages held by the master queue.
        messages = [
            copy.deepcopy(msg) for msg in self._master_queue
            if _should_save_report_msg(msg)
        ]

        # Renumber the deltas consecutively from 0 and remember where the
        # first one sits in the saved message list.
        first_delta_index = 0
        num_deltas = 0
        for idx, msg in enumerate(messages):
            if msg.HasField("delta"):
                msg.metadata.delta_id = num_deltas
                if num_deltas == 0:
                    first_delta_index = idx
                num_deltas += 1

        manifest = self._build_manifest(
            status="done",
            num_messages=len(messages),
            first_delta_index=first_delta_index,
            num_deltas=num_deltas,
        )

        manifest_json = json.dumps(manifest).encode("utf-8")

        # Build a list of message tuples: (message_location, serialized_message)
        message_tuples = [(
            "reports/%(id)s/%(idx)s.pb" % {
                "id": self.report_id,
                "idx": msg_idx
            },
            msg.SerializeToString(),
        ) for msg_idx, msg in enumerate(messages)]

        manifest_tuples = [("reports/%(id)s/manifest.json" % {
            "id": self.report_id
        }, manifest_json)]

        # Manifest must be at the end, so clients don't connect and read the
        # manifest while the deltas haven't been saved yet.
        return message_tuples + manifest_tuples

    def _build_manifest(
        self,
        status,
        num_messages=None,
        first_delta_index=None,
        num_deltas=None,
        external_server_ip=None,
        internal_server_ip=None,
    ):
        """Build a manifest dict for this report.

        Parameters
        ----------
        status : 'done' or 'running'
            The report status. If the script is still executing, then the
            status should be RUNNING. Otherwise, DONE.
        num_messages : int or None
            Set only when status is DONE. The number of ForwardMsgs that this report
            is made of.
        first_delta_index : int or None
            Set only when status is DONE. The index of our first Delta message
        num_deltas : int or None
            Set only when status is DONE. The number of Delta messages in the report
        external_server_ip : str or None
            Only when status is RUNNING. The IP of the Server's websocket.
        internal_server_ip : str or None
            Only when status is RUNNING. The IP of the Server's websocket.

        Returns
        -------
        dict
            The actual manifest. Schema:
            - name: str,
            - numMessages: int or None,
            - firstDeltaIndex: int or None,
            - numDeltas: int or None,
            - serverStatus: 'running' or 'done',
            - configuredServerAddress: the "browser.serverAddress" option
              when status is 'running', otherwise None,
            - externalServerIP: str or None,
            - internalServerIP: str or None,
            - serverPort: the "browser.serverPort" option

        """
        if status == "running":
            configured_server_address = config.get_option(
                "browser.serverAddress")
        else:
            configured_server_address = None

        return dict(
            name=self.name,
            numMessages=num_messages,
            firstDeltaIndex=first_delta_index,
            numDeltas=num_deltas,
            serverStatus=status,
            configuredServerAddress=configured_server_address,
            externalServerIP=external_server_ip,
            internalServerIP=internal_server_ip,
            # Don't use _get_browser_address_bar_port() here, since we want the
            # websocket port, not the web server port. (These are the same in
            # prod, but different in dev)
            serverPort=config.get_option("browser.serverPort"),
        )
예제 #19
0
class Report(object):
    """
    Contains parameters related to running a report, and also houses
    the two ReportQueues (master_queue and browser_queue) that are used
    to deliver messages to a connected browser, and to serialize the
    running report.
    """
    @classmethod
    def get_url(cls, host_ip):
        """Get the URL for any app served at the given host_ip.

        The "server.baseUrlPath" option, when non-empty, is appended as a
        path with a single leading slash.

        Parameters
        ----------
        host_ip : str
            The IP address of the machine that is running the Streamlit Server.

        Returns
        -------
        str
            The URL.
        """
        port = _get_browser_address_bar_port()
        base_path = config.get_option("server.baseUrlPath").strip("/")

        if base_path:
            base_path = "/" + base_path

        return "http://%(host_ip)s:%(port)s%(base_path)s" % {
            "host_ip": host_ip.strip("/"),
            "port": port,
            "base_path": base_path,
        }

    def __init__(self, script_path, command_line):
        """Constructor.

        Parameters
        ----------
        script_path : str
            Path of the Python file from which this app is generated.

        command_line : string
            Command line as input by the user

        """
        basename = os.path.basename(script_path)

        self.script_path = os.path.abspath(script_path)
        self.script_folder = os.path.dirname(self.script_path)
        # The report name is the script's file name without its extension.
        self.name = os.path.splitext(basename)[0]

        # The master queue contains all messages that comprise the report.
        # If the user chooses to share a saved version of the report,
        # we serialize the contents of the master queue.
        self._master_queue = ReportQueue()

        # The browser queue contains messages that haven't yet been
        # delivered to the browser. Periodically, the server flushes
        # this queue and delivers its contents to the browser.
        self._browser_queue = ReportQueue()

        # generate_new_id() is called right away to replace this placeholder.
        self.report_id = None
        self.generate_new_id()

        self.command_line = command_line

    def get_debug(self):
        """Return a dict holding the master queue's debug information."""
        return {"master queue": self._master_queue.get_debug()}

    def enqueue(self, msg):
        """Add ``msg`` to both the master queue and the browser queue."""
        self._master_queue.enqueue(msg)
        self._browser_queue.enqueue(msg)

    def clear(self):
        """Clear both queues, keeping the master queue's initial message."""
        # Master_queue retains its initial message; browser_queue is
        # completely cleared.
        initial_msg = self._master_queue.get_initial_msg()
        self._master_queue.clear()
        if initial_msg:
            self._master_queue.enqueue(initial_msg)

        self._browser_queue.clear()

    def flush_browser_queue(self):
        """Clears our browser queue and returns the messages it contained.

        The Server calls this periodically to deliver new messages
        to the browser connected to this report.

        This doesn't affect the master_queue.

        Returns
        -------
        list[ForwardMsg]
            The messages that were removed from the queue and should
            be delivered to the browser.

        """
        return self._browser_queue.flush()

    def generate_new_id(self):
        """Randomly generate an ID representing this report's execution."""
        # Convert to str for Python2
        self.report_id = str(
            base58.b58encode(uuid.uuid4().bytes).decode("utf-8"))

    def serialize_running_report_to_files(self):
        """Return a running report as an easily-serializable list of tuples.

        Returns
        -------
        list of tuples
            See `CloudStorage.save_report_files()` for schema. But as to the
            output of this method, it's just a manifest pointing to the Server
            so browsers who go to the shareable report URL can connect to it
            live.

        """
        LOGGER.debug("Serializing running report")

        manifest = self._build_manifest(
            status=StaticManifest.RUNNING,
            external_server_ip=net_util.get_external_ip(),
            internal_server_ip=net_util.get_internal_ip(),
        )

        return [("reports/%s/manifest.pb" % self.report_id,
                 manifest.SerializeToString())]

    def serialize_final_report_to_files(self):
        """Return the report as an easily-serializable list of tuples.

        Returns
        -------
        list of tuples
            See `CloudStorage.save_report_files()` for schema. But as to the
            output of this method, it's (1) a simple manifest and (2) a bunch
            of serialized ForwardMsgs.

        """
        LOGGER.debug("Serializing final report")

        messages = [
            copy.deepcopy(msg) for msg in self._master_queue
            if _should_save_report_msg(msg)
        ]

        manifest = self._build_manifest(status=StaticManifest.DONE,
                                        num_messages=len(messages))

        # Build a list of message tuples: (message_location, serialized_message)
        message_tuples = [(
            "reports/%(id)s/%(idx)s.pb" % {
                "id": self.report_id,
                "idx": msg_idx
            },
            msg.SerializeToString(),
        ) for msg_idx, msg in enumerate(messages)]

        manifest_tuples = [(
            "reports/%(id)s/manifest.pb" % {
                "id": self.report_id
            },
            manifest.SerializeToString(),
        )]

        # Manifest must be at the end, so clients don't connect and read the
        # manifest while the deltas haven't been saved yet.
        return message_tuples + manifest_tuples

    def _build_manifest(
        self,
        status,
        num_messages=None,
        external_server_ip=None,
        internal_server_ip=None,
    ):
        """Build a StaticManifest protobuf message for this report.

        Parameters
        ----------
        status : StaticManifest.ServerStatus
            The report status. If the script is still executing, then the
            status should be RUNNING. Otherwise, DONE.
        num_messages : int or None
            Set only when status is DONE. The number of ForwardMsgs that this report
            is made of.
        external_server_ip : str or None
            Only when status is RUNNING. The IP of the Server's websocket.
        internal_server_ip : str or None
            Only when status is RUNNING. The IP of the Server's websocket.

        Returns
        -------
        StaticManifest
            A StaticManifest protobuf message

        """

        manifest = StaticManifest()
        manifest.name = self.name
        manifest.server_status = status

        if status == StaticManifest.RUNNING:
            # Server connection details are only filled in for RUNNING.
            manifest.external_server_ip = external_server_ip
            manifest.internal_server_ip = internal_server_ip
            manifest.configured_server_address = config.get_option(
                "browser.serverAddress")
            # Don't use _get_browser_address_bar_port() here, since we want the
            # websocket port, not the web server port. (These are the same in
            # prod, but different in dev)
            manifest.server_port = config.get_option("browser.serverPort")
            manifest.server_base_path = config.get_option("server.baseUrlPath")
        else:
            # Any other status only records the message count.
            manifest.num_messages = num_messages

        return manifest
예제 #20
0
    def test_simple_add_rows(self):
        """An add-rows message flushes merged into the delta with its id."""
        report_queue = ReportQueue()
        self.assertTrue(report_queue.is_empty())

        report_queue.enqueue(INIT_MSG)

        # The data frame and the add-rows message share delta id 1.
        sequence = (
            (TEXT_DELTA_MSG1, 0),
            (DF_DELTA_MSG, 1),
            (ADD_ROWS_MSG, 1),
        )
        for message, delta_id in sequence:
            message.metadata.delta_id = delta_id
            report_queue.enqueue(message)

        flushed = report_queue.flush()
        self.assertEqual(len(flushed), 3)
        self.assertTrue(flushed[0].initialize.config.sharing_enabled)

        text_msg = flushed[1]
        self.assertEqual(text_msg.metadata.delta_id, 0)
        self.assertEqual(text_msg.delta.new_element.text.body, "text1")

        frame_msg = flushed[2]
        self.assertEqual(frame_msg.metadata.delta_id, 1)
        columns = frame_msg.delta.new_element.data_frame.data.cols
        first_column = columns[0].int64s.data
        second_column = columns[1].int64s.data
        self.assertEqual(first_column, [0, 1, 2, 3, 4, 5])
        self.assertEqual(second_column, [10, 11, 12, 13, 14, 15])
예제 #21
0
class DataFrameStylingTest(unittest.TestCase):
    """Checks how pandas.Styler styling data is marshalled into protobufs.

    Every parameterized test runs twice: once through st.dataframe and once
    through st.table.
    """

    def setUp(self):
        """Create a fresh ReportQueue and a DeltaGenerator that feeds it."""
        self._report_queue = ReportQueue()

        def forward_to_queue(msg):
            # Queue the message; this callback always returns True.
            self._report_queue.enqueue(msg)
            return True

        self._dg = DeltaGenerator(forward_to_queue)

    @parameterized.expand([('dataframe', 'data_frame'), ('table', 'table')])
    def test_unstyled_has_no_style(self, element, proto):
        """An untouched Styler must produce a protobuf without any styling:
        empty display value, has_display_value unset and no css entries.
        """
        frame = pd.DataFrame({'A': [1, 2, 3, 4, 5]})

        emit = getattr(self._dg, element)
        emit(frame.style)
        proto_df = getattr(self._get_element(), proto)

        n_rows, n_cols = frame.shape
        for row_idx in range(n_rows):
            for col_idx in range(n_cols):
                cell_style = get_cell_style(proto_df, col_idx, row_idx)
                self.assertEqual(cell_style.display_value, '')
                self.assertEqual(cell_style.has_display_value, False)
                self.assertEqual(len(cell_style.css), 0)

    @parameterized.expand([('dataframe', 'data_frame'), ('table', 'table')])
    def test_format(self, element, proto):
        """DataFrame.style.format() output shows up as cell display values."""
        raw_values = [0.1, 0.2, 0.3352, np.nan]
        expected_display = ['10.00%', '20.00%', '33.52%', 'nan%']

        frame = pd.DataFrame({'A': raw_values})

        getattr(self._dg, element)(frame.style.format('{:.2%}'))

        proto_df = getattr(self._get_element(), proto)
        self._assert_column_display_values(proto_df, 0, expected_display)

    @parameterized.expand([('dataframe', 'data_frame'), ('table', 'table')])
    def test_css_styling(self, element, proto):
        """DataFrame.style css rules are carried over per cell."""

        def text_color(val):
            # Negative values are painted red, everything else black.
            if val < 0:
                return 'color: red'
            return 'color: black'

        expected_css = [
            {css_s('color', 'red')},
            {css_s('color', 'black'), css_s('background-color', 'yellow')},
        ]

        frame = pd.DataFrame({'A': [-1, 1]})

        emit = getattr(self._dg, element)
        styler = frame.style.highlight_max(color='yellow')
        emit(styler.applymap(text_color))

        proto_df = getattr(self._get_element(), proto)
        self._assert_column_css_styles(proto_df, 0, expected_css)

    @parameterized.expand([('dataframe', 'data_frame'), ('table', 'table')])
    def test_add_styled_rows(self, element, proto):
        """add_rows keeps the existing styles and appends the new ones."""
        initial = pd.DataFrame([5, 6])
        extra = pd.DataFrame([7, 8])

        # Two red cells from the initial frame, then two black appended cells.
        expected_css = [{css_s('color', 'red')} for _ in range(2)]
        expected_css += [{css_s('color', 'black')} for _ in range(2)]

        emit = getattr(self._dg, element)
        dg = emit(initial.style.applymap(lambda _cell: 'color: red'))

        dg.add_rows(extra.style.applymap(lambda _cell: 'color: black'))

        proto_df = getattr(self._get_element(), proto)
        self._assert_column_css_styles(proto_df, 0, expected_css)

    @parameterized.expand([('dataframe', 'data_frame'), ('table', 'table')])
    def test_add_styled_rows_to_unstyled_rows(self, element, proto):
        """Styled rows can be appended to a frame that had no styling."""
        initial = pd.DataFrame([5, 6])
        extra = pd.DataFrame([7, 8])

        # The unstyled rows carry no css; the appended rows are black.
        expected_css = [set() for _ in range(2)]
        expected_css += [{css_s('color', 'black')} for _ in range(2)]

        emit = getattr(self._dg, element)
        dg = emit(initial)
        dg.add_rows(extra.style.applymap(lambda _cell: 'color: black'))

        proto_df = getattr(self._get_element(), proto)
        self._assert_column_css_styles(proto_df, 0, expected_css)

    @parameterized.expand([('dataframe', 'data_frame'), ('table', 'table')])
    def test_add_unstyled_rows_to_styled_rows(self, element, proto):
        """Unstyled rows can be appended to a frame that was styled."""
        initial = pd.DataFrame([5, 6])
        extra = pd.DataFrame([7, 8])

        # The styled rows are black; the appended rows carry no css.
        expected_css = [{css_s('color', 'black')} for _ in range(2)]
        expected_css += [set() for _ in range(2)]

        dg = getattr(self._dg, element)(
            initial.style.applymap(lambda _cell: 'color: black'))

        dg.add_rows(extra)

        proto_df = getattr(self._get_element(), proto)
        self._assert_column_css_styles(proto_df, 0, expected_css)

    def _get_element(self):
        """Return the new_element of the last message in the report queue."""
        latest_msg = self._report_queue._queue[-1]
        return latest_msg.delta.new_element

    def _assert_column_display_values(self, proto_df, col, display_values):
        """Assert that the cells of column `col` show `display_values`.

        For each row, has_display_value must be True exactly when the
        expected value is not None, and display_value must equal it.
        """
        for row, expected in enumerate(display_values):
            cell_style = get_cell_style(proto_df, col, row)
            self.assertEqual(cell_style.has_display_value,
                             expected is not None)
            self.assertEqual(cell_style.display_value, expected)

    def _assert_column_css_styles(self, proto_df, col, expected_styles):
        """Assert that the cells of column `col` carry `expected_styles`.

        expected_styles : List[Set[serialized_proto_str]]
        """
        for row, expected in enumerate(expected_styles):
            cell_style = get_cell_style(proto_df, col, row)
            # Serialize the repeated css entries into a set so that the
            # comparison does not depend on their order.
            actual = {proto_to_str(css) for css in cell_style.css}
            self.assertEqual(expected, actual)