def __init__(self, parent=None):
        """Initialise the plot: base classes, axis labels, state, mouse tracking, legend.

        :param parent: passed to ``gui.OWComponent`` as ``widget`` and to
            ``pg.PlotWidget`` as ``parent``.
        """
        gui.OWComponent.__init__(self, widget=parent)
        custom_view_box = KaplanMeierViewBox(self)
        pg.PlotWidget.__init__(self, parent=parent, viewBox=custom_view_box)

        self.setLabels(left='Survival Probability', bottom='Time')

        # Nothing is highlighted, plotted or selected at start-up.
        self.highlighted_curve: Optional[int] = None
        self.curves: Dict[int, EstimatedFunctionCurve] = {}
        self.__selection_items: Dict[int, Optional[pg.PlotDataItem]] = {}

        self.view_box: KaplanMeierViewBox = self.getViewBox()

        # Route scene mouse-move notifications to mouseMovedEvent through a
        # SignalProxy configured with a delay and a rate limit.
        # NOTE(review): the proxy is stored on self presumably to keep it
        # alive for the lifetime of the plot -- confirm.
        scene = self.plotItem.scene()
        self._mouse_moved_signal = pg.SignalProxy(
            scene.sigMouseMoved,
            slot=self.mouseMovedEvent,
            delay=0.15,
            rateLimit=10,
        )
        self.view_box.selection_changed.connect(self.on_selection_changed)

        # The legend is parented to the view box and anchored with
        # ((1, 0), (1, 0)).
        legend = LegendItem()
        self.legend = legend
        legend.setParentItem(self.getViewBox())
        legend.restoreAnchor(((1, 0), (1, 0)))
class KaplanMeierPlot(gui.OWComponent, pg.PlotWidget):
    """Plot widget showing estimated survival curves.

    The widget keeps a mapping of curve ids to curve objects in ``curves``,
    highlights the curve under the mouse cursor, and lets the user select an
    x-interval on the highlighted curve. ``selection_changed`` is emitted
    whenever an interactive selection gesture is finished.
    """

    HIGHLIGHT_RADIUS = 20  # in pixels
    selection_changed = Signal()

    # Per-curve selected interval, keyed by curve id.
    selection: Dict[int, Optional[SelectionInterval]] = Setting(
        {}, schema_only=True)

    def __init__(self, parent=None):
        """Initialise base classes, axis labels, state, mouse tracking and legend.

        :param parent: passed to ``gui.OWComponent`` as ``widget`` and to
            ``pg.PlotWidget`` as ``parent``.
        """
        gui.OWComponent.__init__(self, widget=parent)
        pg.PlotWidget.__init__(self,
                               parent=parent,
                               viewBox=KaplanMeierViewBox(self))

        self.setLabels(left='Survival Probability', bottom='Time')

        self.highlighted_curve: Optional[int] = None
        self.curves: Dict[int, EstimatedFunctionCurve] = {}
        self.__selection_items: Dict[int, Optional[pg.PlotDataItem]] = {}

        self.view_box: KaplanMeierViewBox = self.getViewBox()

        # Scene mouse-move notifications reach mouseMovedEvent through a
        # SignalProxy configured with a delay and a rate limit.
        self._mouse_moved_signal = pg.SignalProxy(
            self.plotItem.scene().sigMouseMoved,
            slot=self.mouseMovedEvent,
            delay=0.15,
            rateLimit=10)
        self.view_box.selection_changed.connect(self.on_selection_changed)

        self.legend = LegendItem()
        self.legend.setParentItem(self.getViewBox())
        self.legend.restoreAnchor(((1, 0), (1, 0)))

    def mouseMovedEvent(self, ev):
        """Highlight the first curve lying within HIGHLIGHT_RADIUS of the mouse.

        :param ev: sequence of signal arguments; ``ev[0]`` is the mouse
            position in scene coordinates.

        The hovered curve is determined first and ``highlight`` is called
        exactly once, so a curve that stays under the cursor is not
        un-highlighted and re-highlighted on every mouse move.
        """
        pos = self.view_box.mapSceneToView(ev[0])
        mouse_x_pos, mouse_y_pos = pos.x(), pos.y()

        # Convert the pixel radius into view (data) units on each axis.
        x_pixel_size, y_pixel_size = self.view_box.viewPixelSize()
        x_pixel = self.HIGHLIGHT_RADIUS * x_pixel_size
        y_pixel = self.HIGHLIGHT_RADIUS * y_pixel_size

        hovered = None
        for curve_id, curve in self.curves.items():
            if self._is_near_curve(curve, mouse_x_pos, mouse_y_pos,
                                   x_pixel, y_pixel):
                hovered = curve_id
                break

        self.highlight(hovered)

    @staticmethod
    def _is_near_curve(curve, mouse_x, mouse_y, x_tolerance,
                       y_tolerance) -> bool:
        """Return True if the point (mouse_x, mouse_y) lies close to ``curve``.

        The curve's ``x``/``y`` points are turned into consecutive segments.
        Segments whose x changes are tested with a tolerance on the y axis,
        segments whose y changes are tested with a tolerance on the x axis.
        """
        points = np.column_stack((curve.x, curve.y))
        # One row per segment: (x_start, y_start, x_end, y_end).
        segments = np.column_stack((points[:-1, :], points[1:, :]))

        horizontal = segments[segments[:, 0] != segments[:, 2]]
        vertical = segments[segments[:, 1] != segments[:, 3]]

        on_horizontal = (
            # check X axis
            (horizontal[:, 0] < mouse_x)
            & (mouse_x < horizontal[:, 2])
            # check Y axis
            & (horizontal[:, 1] + y_tolerance > mouse_y)
            & (mouse_y > horizontal[:, 3] - y_tolerance))
        on_vertical = (
            # check X axis
            (vertical[:, 0] - x_tolerance < mouse_x)
            & (mouse_x < vertical[:, 2] + x_tolerance)
            # check Y axis (the start of the segment must lie above its end)
            & (vertical[:, 1] > mouse_y)
            & (mouse_y > vertical[:, 3]))

        return bool(np.any(on_horizontal) or np.any(on_vertical))

    def highlight(self, curve_id: Optional[int]):
        """Make ``curve_id`` the highlighted curve (``None`` clears the highlight).

        The previously highlighted curve, if any, is always un-highlighted
        when the highlight moves, including when it moves directly from one
        curve to another. A previous id that is no longer present in
        ``curves`` is ignored.

        :raises KeyError: if ``curve_id`` is not ``None`` and not in ``curves``.
        """
        previous = self.highlighted_curve
        if previous == curve_id:
            return

        # Resolve the new curve before touching any state so that an unknown
        # id leaves the widget unchanged.
        new_curve = None if curve_id is None else self.curves[curve_id]
        self.highlighted_curve = curve_id

        if previous is not None:
            previous_curve = self.curves.get(previous)
            if previous_curve is not None:
                previous_curve.set_highlighted(False)

        if new_curve is not None:
            new_curve.set_highlighted(True)

    def clear_selection(self, curve_id: Optional[int] = None):
        """Clear the selection of one curve, or of all curves.

        :param curve_id: id of the curve whose selection is hidden and
            removed; when ``None`` the selections of all curves are cleared.
        """
        if curve_id is not None:
            self.curves[curve_id].selection.hide()
            self.selection = {
                key: val
                for key, val in self.selection.items() if key != curve_id
            }
            return

        for curve in self.curves.values():
            curve.selection.hide()

        self.selection = {}

    def set_selection(self):
        """Display the stored selection of every curve id in ``selection``."""
        for curve_id in self.selection:
            self.set_selection_item(curve_id)

    def set_selection_item(self, curve_id: int):
        """Display the stored selection of ``curve_id``; do nothing if it has none."""
        if curve_id not in self.selection:
            return

        selection = self.selection[curve_id]
        curve = self.curves[curve_id]
        curve.selection.setData(selection.x, selection.y)
        curve.selection.show()

    def on_selection_changed(self, selection_interval, is_finished):
        """React to a selection gesture reported by the view box.

        The current selection of the highlighted curve is cleared first (all
        selections when no curve is highlighted). If a curve is highlighted
        and an interval is given, the part of the interval that overlaps the
        curve becomes its new selection.

        :param selection_interval: pair of x values; may be given in any order.
        :param is_finished: when true, ``selection_changed`` is emitted. This
            includes the case where the interval lies completely outside the
            curve, because the previous selection has been cleared by then.
        """
        self.clear_selection(self.highlighted_curve)

        if self.highlighted_curve is not None and selection_interval:
            self._select_interval(self.highlighted_curve, selection_interval)

        if is_finished:
            self.selection_changed.emit()

    def _select_interval(self, curve_id: int, selection_interval):
        """Store and display the part of ``selection_interval`` overlapping the curve."""
        curve = self.curves[curve_id]
        start_x, end_x = sorted(selection_interval)
        if end_x < curve.x[0] or start_x > curve.x[-1]:
            # The interval does not touch the curve at all.
            return

        # Clamp the interval to the x-range covered by the curve.
        start_x = max(curve.x[0], start_x)
        end_x = min(curve.x[-1], end_x)
        # Index of the first point strictly to the right of each bound.
        left, right = np.argmax(curve.x > start_x), np.argmax(curve.x > end_x)
        # argmax yields 0 when no point lies beyond end_x; use the last point.
        right = right if right else -1

        left_selection = (start_x, curve.y[left])
        right_selection = (end_x, curve.y[right])
        middle_selection = np.column_stack((curve.x, curve.y))[left:right]
        selected = np.vstack(
            (left_selection, middle_selection, right_selection))
        self.selection[curve_id] = SelectionInterval(
            selected[:, 0], selected[:, 1])
        self.set_selection_item(curve_id)

    def update_plot(self,
                    confidence_interval=False,
                    median=False,
                    censored=False):
        """Clear the plot and legend, then re-add the items of every curve.

        :param confidence_interval: also add each curve's confidence limits
            and confidence interval items.
        :param median: also add the shared horizontal line and each curve's
            median vertical item.
        :param censored: also add each curve's censored data item.
        """
        self.clear()
        self.legend.clear()

        if not self.curves:
            return

        if median:
            self.addItem(HORIZONTAL_LINE)

        for curve in self.curves.values():
            self.addItem(curve.estimated_fun)
            self.addItem(curve.selection)

            if confidence_interval:
                self.addItem(curve.lower_conf_limit)
                self.addItem(curve.upper_conf_limit)
                self.addItem(curve.confidence_interval)

            if median:
                self.addItem(curve.median_vertical)

            if censored:
                self.addItem(curve.censored_data)

        self.set_selection()
        self.update_legend()

    def update_legend(self):
        """Add a legend entry for each curve with a color and a label.

        The legend is hidden first and shown again only if it ends up with
        at least one item.
        """
        self.legend.hide()

        for curve in self.curves.values():
            if not (curve.color and curve.label):
                continue
            c = QColor(*curve.color)
            dot = pg.ScatterPlotItem(pen=c, brush=c, size=10, symbol='s')
            self.legend.addItem(dot, escape(curve.label))

        if self.legend.items:
            self.legend.show()

    def select_button_clicked(self):
        """Switch the view box to selection mode (rectangle mouse mode)."""
        self.view_box.mode = SELECT
        self.view_box.setMouseMode(self.view_box.RectMode)

    def pan_button_clicked(self):
        """Switch the view box to panning mode (pan mouse mode)."""
        self.view_box.mode = PANNING
        self.view_box.setMouseMode(self.view_box.PanMode)

    def zoom_button_clicked(self):
        """Switch the view box to zooming mode (rectangle mouse mode)."""
        self.view_box.mode = ZOOMING
        self.view_box.setMouseMode(self.view_box.RectMode)

    def reset_button_clicked(self):
        """Auto-range the view box and re-enable automatic ranging."""
        self.view_box.autoRange()
        self.view_box.enableAutoRange()
示例#3
0
 def _create_legend(self, anchor, brush=QBrush(QColor(232, 232, 232, 200))):
     """Create a legend parented to the plot widget's view box.

     :param anchor: anchor description handed to ``restoreAnchor``.
     :param brush: background brush of the legend. The default is used
         because the legend's default background was too transparent for
         colorful maps.
     :return: the newly created ``LegendItem``.
     """
     item = LegendItem(brush=brush)
     view_box = self.plot_widget.getViewBox()
     item.setParentItem(view_box)
     item.restoreAnchor(anchor)
     return item
示例#4
0
 def _create_legend(self, anchor):
     """Create a legend parented to the plot widget's view box.

     :param anchor: anchor description handed to ``restoreAnchor``.
     :return: the newly created ``LegendItem``.
     """
     item = LegendItem()
     view_box = self.plot_widget.getViewBox()
     item.setParentItem(view_box)
     item.restoreAnchor(anchor)
     return item
示例#5
0
 def _create_legend(self):
     """Create a hidden legend parented to this widget's view box.

     The legend is anchored with item position (1, 0), parent position
     (1, 0) and an offset of (-3, 1).

     :return: the newly created ``LegendItem``.
     """
     item = LegendItem()
     view_box = self.getViewBox()
     item.setParentItem(view_box)
     item.anchor((1, 0), (1, 0), offset=(-3, 1))
     item.hide()
     return item