Exemplo n.º 1
0
class OWScatterPlotGraph(gui.OWComponent, ScaleScatterPlotData):
    """Scatter plot graph component drawn on a pyqtgraph ``PlotWidget``.

    Keeps the plot widget, the scatter plot item, the legends, the labels
    and the current point selection for the owning scatter plot widget.
    """
    # Names of the attributes that determine point colour, label, shape and
    # size; an empty string means that no attribute is chosen.
    attr_color = ContextSetting("", ContextSetting.OPTIONAL)
    attr_label = ContextSetting("", ContextSetting.OPTIONAL)
    attr_shape = ContextSetting("", ContextSetting.OPTIONAL)
    attr_size = ContextSetting("", ContextSetting.OPTIONAL)

    # Point size: used as is when no size attribute is chosen, otherwise as
    # the multiplier of the scaled size attribute (see compute_sizes).
    point_width = Setting(10)
    # Alpha channel used for the point brushes.
    alpha_value = Setting(255)
    # Whether grid lines are shown (see update_grid).
    show_grid = Setting(False)
    # Whether the legend is visible (see update_legend).
    show_legend = Setting(True)
    # If true, the tooltip lists all attributes, not only the two shown.
    tooltip_shows_all = Setting(False)
    # NOTE(review): the two settings below are not read by any method of
    # this class visible here -- presumably used elsewhere; confirm.
    square_granularity = Setting(3)
    space_between_cells = Setting(True)

    # Symbols indexed by the value of the shape attribute; the last one
    # ("?") is used for unknown values (see compute_symbols).
    CurveSymbols = np.array("o x t + d s ?".split())
    # Smallest size of a point whose size depends on an attribute.
    MinShapeSize = 6
    # Factor passed to QColor.darker / QColor.lighter for pens and legends.
    DarkerValue = 120
    # NOTE(review): not referenced by the methods visible here.
    UnknownColor = (168, 50, 168)

    def __init__(self, scatter_widget, parent=None, _="None"):
        """Create the plot widget and initialise the state of the graph.

        :param scatter_widget: the owning widget; passed to
            ``gui.OWComponent.__init__`` and stored as ``self.master``
        :param parent: parent given to the ``pg.PlotWidget``
        :param _: unused
        """
        gui.OWComponent.__init__(self, scatter_widget)
        self.view_box = InteractiveViewBox(self)
        self.plot_widget = pg.PlotWidget(viewBox=self.view_box, parent=parent)
        self.plot_widget.setAntialiasing(True)
        # Bind the widget's ``replot`` method (the same one update_data
        # calls) rather than the widget itself, so ``graph.replot()`` works.
        self.replot = self.plot_widget.replot
        ScaleScatterPlotData.__init__(self)
        self.scatterplot_item = None

        self.tooltip_data = []
        self.tooltip = TextItem(
            border=pg.mkPen(200, 200, 200), fill=pg.mkBrush(250, 250, 200, 220))
        self.tooltip.hide()

        # One TextItem per plotted point; filled by create_labels
        self.labels = []

        self.master = scatter_widget
        self.shown_attribute_indices = []
        self.shown_x = ""
        self.shown_y = ""
        # Cached pens/brushes; reset by compute_colors unless keep_colors
        self.pen_colors = self.brush_colors = None

        self.valid_data = None  # np.ndarray
        self.selection = None  # np.ndarray
        self.n_points = 0

        self.gui = OWPlotGUI(self)
        self.continuous_palette = ContinuousPaletteGenerator(
            QColor(255, 255, 0), QColor(0, 0, 255), True)
        self.discrete_palette = ColorPaletteGenerator()

        self.selection_behavior = 0

        self.legend = self.color_legend = None
        self.scale = None  # DiscretizedScale

        self.tips = TooltipManager(self)
        # self.setMouseTracking(True)
        # self.grabGesture(QPinchGesture)
        # self.grabGesture(QPanGesture)

        self.update_grid()

    def set_data(self, data, subset_data=None, **args):
        """Clear the plot and pass the new data on to ScaleScatterPlotData."""
        plot = self.plot_widget
        plot.clear()
        ScaleScatterPlotData.set_data(self, data, subset_data, **args)

    def update_data(self, attr_x, attr_y):
        """Rebuild the scatter plot for attributes *attr_x* and *attr_y*.

        Removes the previous scatter plot item, labels and legend, then --
        if there is any scaled data -- plots the rows that are valid for
        both attributes and recreates the labels and the legend.

        :param attr_x: name of the attribute shown on the bottom axis
        :param attr_y: name of the attribute shown on the left axis
        """
        self.shown_x = attr_x
        self.shown_y = attr_y

        self.remove_legend()
        if self.scatterplot_item:
            self.plot_widget.removeItem(self.scatterplot_item)
        for label in self.labels:
            self.plot_widget.removeItem(label)
        self.labels = []
        self.tooltip_data = []
        self.set_axis_title("bottom", "")
        self.set_axis_title("left", "")

        if self.scaled_data is None or not len(self.scaled_data):
            self.valid_data = None
            self.n_points = 0
            return

        index_x = self.attribute_name_index[attr_x]
        index_y = self.attribute_name_index[attr_y]
        self.valid_data = self.get_valid_list([index_x, index_y])
        x_data, y_data = self.get_xy_data_positions(
            attr_x, attr_y, self.valid_data)
        # Keep only the rows that are valid for both shown attributes
        x_data = x_data[self.valid_data]
        y_data = y_data[self.valid_data]
        self.n_points = len(x_data)

        for axis, name, index in (("bottom", attr_x, index_x),
                                  ("left", attr_y, index_y)):
            self.set_axis_title(axis, name)
            var = self.data_domain[index]
            if isinstance(var, DiscreteVariable):
                self.set_labels(axis, get_variable_values_sorted(var))

        color_data, brush_data = self.compute_colors()
        size_data = self.compute_sizes()
        shape_data = self.compute_symbols()
        # Each point stores its position among the plotted points as data
        self.scatterplot_item = ScatterPlotItem(
            x=x_data, y=y_data, data=np.arange(self.n_points),
            symbol=shape_data, size=size_data, pen=color_data, brush=brush_data)
        self.plot_widget.addItem(self.scatterplot_item)
        self.plot_widget.addItem(self.tooltip)
        self.scatterplot_item.selected_points = []
        self.scatterplot_item.sigClicked.connect(self.select_by_click)

        # The scene outlives the scatter plot item, so connecting on every
        # call would stack up connections and run mouseMoved several times
        # per event; drop the previous connection (if any) first.
        scene = self.scatterplot_item.scene()
        try:
            scene.sigMouseMoved.disconnect(self.mouseMoved)
        except (TypeError, RuntimeError):
            pass  # mouseMoved was not connected yet
        scene.sigMouseMoved.connect(self.mouseMoved)

        self.update_labels()
        self.make_legend()
        self.plot_widget.replot()

    def set_labels(self, axis, labels):
        """Show *labels* as ticks at positions 0, 1, ... on the given axis.

        An empty or missing *labels* resets the ticks by passing ``None``.
        """
        axis_item = self.plot_widget.getAxis(axis)
        if not labels:
            axis_item.setTicks(None)
            return
        axis_item.setTicks([list(enumerate(labels))])

    def set_axis_title(self, axis, title):
        """Set the text label of the given plot axis."""
        plot = self.plot_widget
        plot.setLabel(axis=axis, text=title)

    def get_size_index(self):
        """Return the index of the size attribute, or -1 if none is chosen."""
        name = self.attr_size
        if name == "" or name == "(Same size)":
            return -1
        return self.attribute_name_index[name]

    def compute_sizes(self):
        """Return an array with the size of every plotted point.

        Without a size attribute all points get ``point_width``. Otherwise
        the size is ``MinShapeSize`` plus the scaled attribute value times
        ``point_width``; unknown values get ``MinShapeSize - 2``.
        """
        size_index = self.get_size_index()
        if size_index == -1:
            size_data = np.full((self.n_points,), self.point_width)
        else:
            # Only rows in valid_data are plotted (see update_data), so the
            # sizes must be filtered the same way to match the points.
            scaled = self.no_jittering_scaled_data[size_index, self.valid_data]
            size_data = self.MinShapeSize + scaled * self.point_width
        size_data[np.isnan(size_data)] = self.MinShapeSize - 2
        return size_data

    def update_sizes(self):
        """Recompute point sizes and apply them to the scatter plot item."""
        item = self.scatterplot_item
        if not item:
            return
        item.setSize(self.compute_sizes())

    update_point_size = update_sizes

    def get_color_index(self):
        """Return the index of the colour attribute, or -1 if none is chosen.

        For a discrete colour attribute the discrete palette is also resized
        to the number of its values.
        """
        name = self.attr_color
        if name == "" or name == "(Same color)":
            return -1
        index = self.attribute_name_index[name]
        variable = self.data_domain[name]
        if isinstance(variable, DiscreteVariable):
            self.discrete_palette.set_number_of_colors(len(variable.values))
        return index

    def compute_colors(self, keep_colors=False):
        """Return per-point pens and brushes as a ``(pen, brush)`` pair.

        Pens and base brush colours are cached in ``self.pen_colors`` and
        ``self.brush_colors``; they are recomputed unless *keep_colors* is
        true and the cache is filled. When a selection exists, unselected
        points get a fully transparent brush.

        :param keep_colors: reuse the cached colours and only reapply the
            selection (used when just the selection changed)
        """
        if not keep_colors:
            self.pen_colors = self.brush_colors = None
        color_index = self.get_color_index()
        if color_index == -1:
            # No colour attribute: the same pen for all points
            color = self.plot_widget.palette().color(OWPalette.Data)
            pen = [QPen(QBrush(color), 1.5)] * self.n_points
            if self.selection is not None:
                # Index 0 is for unselected points (transparent, like in the
                # other two branches), index 1 for selected ones.
                brushes = (QBrush(QColor(128, 128, 128, 0)),
                           QBrush(QColor(128, 128, 128)))
                # int() because numpy booleans are not valid tuple indices
                brush = [brushes[int(s)] for s in self.selection]
            else:
                brush = [QBrush(QColor(128, 128, 128))] * self.n_points
            return pen, brush

        c_data = self.original_data[color_index, self.valid_data]
        if isinstance(self.data_domain[color_index], ContinuousVariable):
            if self.pen_colors is None:
                # nanmin/nanmax: unknown values must not turn the bounds of
                # the scale into NaN
                self.scale = DiscretizedScale(
                    np.nanmin(c_data), np.nanmax(c_data))
                # Map every value to the centre of its bin, scaled to [0, 1]
                c_data -= self.scale.offset
                c_data /= self.scale.width
                c_data = np.floor(c_data) + 0.5
                c_data /= self.scale.bins
                c_data = np.clip(c_data, 0, 1)
                palette = self.continuous_palette
                self.pen_colors = palette.getRGB(c_data)
                # Brush colours are the palette colours plus an alpha column
                self.brush_colors = np.hstack(
                    [self.pen_colors,
                     np.full((self.n_points, 1), self.alpha_value)])
                # Pens are a darker shade of the brush colour
                self.pen_colors *= 100 / self.DarkerValue
                self.pen_colors = [QPen(QBrush(QColor(*col)), 1.5)
                                   for col in self.pen_colors.tolist()]
            if self.selection is not None:
                self.brush_colors[:, 3] = 0
                self.brush_colors[self.selection, 3] = self.alpha_value
            else:
                self.brush_colors[:, 3] = self.alpha_value
            pen = self.pen_colors
            brush = np.array([QBrush(QColor(*col))
                              for col in self.brush_colors.tolist()])
        else:
            if self.pen_colors is None:
                palette = self.discrete_palette
                n_colors = palette.number_of_colors
                c_data = c_data.copy()
                # Unknown values use the extra (grey) colour at the end
                c_data[np.isnan(c_data)] = n_colors
                c_data = c_data.astype(int)
                colors = palette.getRGB(np.arange(n_colors + 1))
                colors[n_colors] = (128, 128, 128)
                pens = np.array(
                    [QPen(QBrush(QColor(*col).darker(self.DarkerValue)), 1.5)
                     for col in colors])
                self.pen_colors = pens[c_data]
                # Column 0: brush of unselected points, column 1: selected
                self.brush_colors = np.array([
                    [QBrush(QColor(0, 0, 0, 0)),
                     QBrush(QColor(col[0], col[1], col[2], self.alpha_value))]
                    for col in colors])
                self.brush_colors = self.brush_colors[c_data]
            if self.selection is not None:
                brush = np.where(
                    self.selection,
                    self.brush_colors[:, 1], self.brush_colors[:, 0])
            else:
                brush = self.brush_colors[:, 1]
            pen = self.pen_colors
        return pen, brush

    def update_colors(self, keep_colors=False):
        """Recompute pens and brushes and apply them to the plotted points.

        The legend is rebuilt unless *keep_colors* is true.
        """
        item = self.scatterplot_item
        if not item:
            return
        pens, brushes = self.compute_colors(keep_colors)
        item.setPen(pens, update=False, mask=None)
        item.setBrush(brushes, mask=None)
        if keep_colors:
            return
        self.make_legend()

    update_alpha_value = update_colors

    def create_labels(self):
        """Add an empty TextItem at the position of every plotted point."""
        plot = self.plot_widget
        for point_x, point_y in zip(*self.scatterplot_item.getData()):
            text_item = TextItem()
            plot.addItem(text_item)
            text_item.setPos(point_x, point_y)
            self.labels.append(text_item)

    def update_labels(self):
        """Set the label texts from the attribute named ``attr_label``.

        With no label attribute the existing labels are emptied. Labels are
        created on first use, one per plotted point.
        """
        if not self.attr_label:
            for label in self.labels:
                label.setText("")
            return
        if not self.labels:
            self.create_labels()
        label_column = self.raw_data.get_column_view(self.attr_label)[0]
        if self.valid_data is not None:
            # There is one label per *plotted* point, and only rows in
            # valid_data are plotted; filter the column the same way so each
            # label gets the value of its own row.
            label_column = label_column[self.valid_data]
        formatter = self.raw_data.domain[self.attr_label].str_val
        black = pg.mkColor(0, 0, 0)
        for label, value in zip(self.labels, label_column):
            label.setText(formatter(value), black)

    def get_shape_index(self):
        """Return the index of the shape attribute, or -1 if it is unusable.

        The attribute is unusable when none is chosen or when it has more
        values than there are symbols in ``CurveSymbols``.
        """
        name = self.attr_shape
        if not name or name == "(Same shape)":
            return -1
        if len(self.data_domain[name].values) > len(self.CurveSymbols):
            return -1
        return self.attribute_name_index[name]

    def compute_symbols(self):
        """Return an array with the symbol of every plotted point.

        Without a usable shape attribute all points get the first symbol.
        Unknown values are drawn with the last symbol in ``CurveSymbols``.
        """
        shape_index = self.get_shape_index()
        if shape_index == -1:
            shape_data = self.CurveSymbols[np.zeros(self.n_points, dtype=int)]
        else:
            # Select only the plotted (valid) rows so the symbols line up
            # with the points. np.array() makes a copy, so replacing the
            # unknown values below does not write into original_data.
            shape_data = np.array(
                self.original_data[shape_index, self.valid_data])
            shape_data[np.isnan(shape_data)] = len(self.CurveSymbols) - 1
            shape_data = self.CurveSymbols[shape_data.astype(int)]
        return shape_data

    def update_shapes(self):
        """Apply freshly computed symbols to the points; rebuild the legend."""
        item = self.scatterplot_item
        if item:
            item.setSymbol(self.compute_symbols())
        self.make_legend()

    def update_grid(self):
        """Show or hide both grid lines according to ``show_grid``."""
        visible = self.show_grid
        self.plot_widget.showGrid(x=visible, y=visible)

    def update_legend(self):
        """Apply the ``show_legend`` setting to the legend, if there is one."""
        legend = self.legend
        if not legend:
            return
        legend.setVisible(self.show_legend)

    def create_legend(self):
        """Create the legend item for the plot and store it in ``legend``."""
        plot_item = self.plot_widget.plotItem
        self.legend = PositionedLegendItem(plot_item, self)

    def remove_legend(self):
        """Detach the legend and the colour legend, if present, and forget them."""
        for attr_name in ("legend", "color_legend"):
            item = getattr(self, attr_name)
            if item:
                item.setParent(None)
                setattr(self, attr_name, None)

    def make_legend(self):
        """Rebuild the colour and shape legends from scratch."""
        steps = (self.remove_legend,
                 self.make_color_legend,
                 self.make_shape_legend,
                 self.update_legend)
        for step in steps:
            step()

    def make_color_legend(self):
        """Add the legend for the colour attribute, if one is chosen.

        A discrete attribute gets one entry per value in the common legend
        (using the value's symbol if the same attribute also determines the
        shape); any other attribute gets a separate legend with a sample of
        the continuous palette.
        """
        color_index = self.get_color_index()
        if color_index == -1:
            return
        color_var = self.data_domain[color_index]
        use_shape = self.get_shape_index() == color_index

        if not isinstance(color_var, DiscreteVariable):
            legend = PositionedLegendItem(
                self.plot_widget.plotItem,
                self, legend_id="colors", at_bottom=True)
            self.color_legend = legend
            sample = PaletteItemSample(self.continuous_palette, self.scale)
            legend.addItem(sample, "")
            legend.setGeometry(sample.boundingRect())
            return

        if not self.legend:
            self.create_legend()
        palette = self.discrete_palette
        for i, value in enumerate(color_var.values):
            color = QColor(*palette.getRGB(i))
            brush = color.lighter(self.DarkerValue)
            symbol = self.CurveSymbols[i] if use_shape else "o"
            entry = ScatterPlotItem(
                pen=color, brush=brush, size=10, symbol=symbol)
            self.legend.addItem(entry, value)

    def make_shape_legend(self):
        """Add legend entries for the values of the shape attribute.

        Nothing is added if there is no usable shape attribute or if it is
        the same as the colour attribute.
        """
        shape_index = self.get_shape_index()
        if shape_index == -1:
            return
        if shape_index == self.get_color_index():
            return
        if not self.legend:
            self.create_legend()
        shape_var = self.data_domain[shape_index]
        color = self.plot_widget.palette().color(OWPalette.Data)
        # The pen is derived before the alpha of the brush colour is changed
        pen = QPen(color.darker(self.DarkerValue))
        color.setAlpha(self.alpha_value)
        for i, value in enumerate(shape_var.values):
            entry = ScatterPlotItem(
                pen=pen, brush=color, size=10, symbol=self.CurveSymbols[i])
            self.legend.addItem(entry, value)

    # noinspection PyPep8Naming
    def mouseMoved(self, pos):
        """Show a tooltip describing the points under the mouse cursor.

        :param pos: mouse position in scene coordinates
        """
        act_pos = self.scatterplot_item.mapFromScene(pos)
        points = self.scatterplot_item.pointsAt(act_pos)
        if not len(points):
            self.tooltip.hide()
            return

        # A point stores its position among the *plotted* points (see
        # update_data), and only rows in valid_data are plotted. Translate
        # it to the row of raw_data (same mapping as in get_selection) so
        # the tooltip shows the right instance when some rows are skipped.
        rows = np.arange(len(self.raw_data))[self.valid_data]
        descriptions = []
        for p in points:
            instance = self.raw_data[int(rows[p.data()])]
            parts = ["Attributes:\n"]
            if self.tooltip_shows_all:
                parts.extend(
                    '   {} = {}\n'.format(attr.name, instance[attr])
                    for attr in self.data_domain.attributes)
            else:
                parts.append('   {} = {}\n   {} = {}\n'.format(
                    self.shown_x, instance[self.shown_x],
                    self.shown_y, instance[self.shown_y]))
            if self.data_domain.class_var:
                parts.append('Class:\n   {} = {}\n'.format(
                    self.data_domain.class_var.name,
                    instance[self.raw_data.domain.class_var]))
            descriptions.append("".join(parts))
        # Descriptions of multiple points are separated by a dashed line
        text = '------------------\n'.join(descriptions)
        self.tooltip.setText(text, color=(0, 0, 0))
        self.tooltip.setPos(act_pos)
        self.tooltip.show()
        self.tooltip.setZValue(10)

    def zoom_button_clicked(self):
        """Switch the view box of the scatter plot item to ``RectMode``."""
        item = self.scatterplot_item
        mode = item.getViewBox().RectMode
        item.getViewBox().setMouseMode(mode)

    def pan_button_clicked(self):
        """Switch the view box of the scatter plot item to ``PanMode``."""
        item = self.scatterplot_item
        mode = item.getViewBox().PanMode
        item.getViewBox().setMouseMode(mode)

    def select_button_clicked(self):
        """Switch the view box of the scatter plot item to ``RectMode``."""
        item = self.scatterplot_item
        mode = item.getViewBox().RectMode
        item.getViewBox().setMouseMode(mode)

    def reset_button_clicked(self):
        """Let the view box pick its range automatically."""
        view_box = self.view_box
        view_box.autoRange()

    def select_by_click(self, _, points):
        """Slot for ``sigClicked``: select the clicked *points*."""
        clicked = points
        self.select(clicked)

    def select_by_rectangle(self, value_rect):
        """Select the points whose position lies inside *value_rect*."""
        inside = []
        for point in self.scatterplot_item.points():
            if value_rect.contains(QPointF(point.pos())):
                inside.append(point)
        self.select(inside)

    def unselect_all(self):
        """Drop the selection and redraw the points with the cached colours."""
        self.selection = None
        self.update_colors(True)

    def select(self, points):
        """Change the selection with *points* according to the held modifiers.

        Without a modifier the points replace the selection; with Shift they
        are added to it, with Control removed from it, and with Alt their
        state is toggled. The owning widget is then notified through
        ``selection_changed``.

        :param points: scatter plot points whose ``data()`` is their index
            within the selection array
        """
        # noinspection PyArgumentList
        keys = QApplication.keyboardModifiers()
        # Flags are combined with `|`; adding them is not a flag operation
        modifiers = Qt.ShiftModifier | Qt.ControlModifier | Qt.AltModifier
        if self.selection is None or not keys & modifiers:
            # The builtin `bool`: the `np.bool` alias was removed from NumPy
            self.selection = np.full(self.n_points, False, dtype=bool)
        indices = [p.data() for p in points]
        if keys & Qt.ControlModifier:
            self.selection[indices] = False
        elif keys & Qt.AltModifier:
            self.selection[indices] = ~self.selection[indices]
        else:  # Handle shift and no modifiers
            self.selection[indices] = True
        self.update_colors(keep_colors=True)
        self.master.selection_changed()

    def get_selection(self):
        """Return the indices of the selected rows of ``raw_data``.

        The result is an empty integer array when there is no selection.
        """
        if self.selection is None:
            return np.array([], dtype=int)
        plotted_rows = np.arange(len(self.raw_data))[self.valid_data]
        return plotted_rows[self.selection]

    def set_palette(self, p):
        """Set palette *p* on the plot widget."""
        plot = self.plot_widget
        plot.setPalette(p)

    def save_to_file(self, size):
        """Placeholder that does nothing; *size* is ignored."""
        pass
Exemplo n.º 2
0
class OWScatterPlotGraph(gui.OWComponent, ScaleScatterPlotData):
    """Scatter plot graph component drawn on a pyqtgraph ``PlotWidget``.

    Besides the points it keeps a second scatter plot item that outlines
    the selected points, an optional class density image, and the legends.
    """
    # Names of the attributes that determine point colour, label, shape and
    # size; an empty string means that no attribute is chosen.
    attr_color = ContextSetting("", ContextSetting.OPTIONAL)
    attr_label = ContextSetting("", ContextSetting.OPTIONAL)
    attr_shape = ContextSetting("", ContextSetting.OPTIONAL)
    attr_size = ContextSetting("", ContextSetting.OPTIONAL)

    # Point size: used as is when no size attribute is chosen, otherwise as
    # the multiplier of the scaled size attribute (see compute_sizes).
    point_width = Setting(10)
    # Alpha channel used for the point brushes.
    alpha_value = Setting(128)
    # Whether grid lines are shown (see update_grid).
    show_grid = Setting(False)
    # Whether the legend is visible (see update_legend).
    show_legend = Setting(True)
    # NOTE(review): not read by the methods visible here -- presumably used
    # by the tooltip code; confirm.
    tooltip_shows_all = Setting(False)
    # Whether the class density image is drawn (see should_draw_density).
    class_density = Setting(False)
    # Resolution passed to classdensity.class_density_image.
    resolution = 256

    # Symbols indexed by the value of the shape attribute; the last one
    # ("?") is used for unknown values (see compute_symbols).
    CurveSymbols = np.array("o x t + d s ?".split())
    # Smallest size of a point whose size depends on an attribute.
    MinShapeSize = 6
    # Factor passed to QColor.darker when deriving pen colours.
    DarkerValue = 120
    # NOTE(review): not referenced by the methods visible here.
    UnknownColor = (168, 50, 168)

    def __init__(self, scatter_widget, parent=None, _="None"):
        """Create the plot widget and initialise the state of the graph.

        :param scatter_widget: the owning widget; passed to
            ``gui.OWComponent.__init__`` and stored as ``self.master``
        :param parent: parent given to the ``pg.PlotWidget``
        :param _: unused
        """
        gui.OWComponent.__init__(self, scatter_widget)
        self.view_box = InteractiveViewBox(self)
        self.plot_widget = pg.PlotWidget(viewBox=self.view_box, parent=parent,
                                         background="w")
        self.plot_widget.getPlotItem().buttonsHidden = True
        self.plot_widget.setAntialiasing(True)
        # Override the widget's size hint with a fixed 500x500
        self.plot_widget.sizeHint = lambda: QtCore.QSize(500,500)

        self.replot = self.plot_widget.replot
        ScaleScatterPlotData.__init__(self)
        # Plot items; all three are (re)created in update_data
        self.density_img = None
        self.scatterplot_item = None
        self.scatterplot_item_sel = None

        # One TextItem per plotted point; filled by create_labels
        self.labels = []

        self.master = scatter_widget
        self.shown_attribute_indices = []
        self.shown_x = ""
        self.shown_y = ""
        # Cached pens/brushes; reset by compute_colors unless keep_colors
        self.pen_colors = self.brush_colors = None

        self.valid_data = None  # np.ndarray
        self.selection = None  # np.ndarray
        self.n_points = 0

        self.gui = OWPlotGUI(self)
        self.continuous_palette = ContinuousPaletteGenerator(
            QColor(255, 255, 0), QColor(0, 0, 255), True)
        self.discrete_palette = ColorPaletteGenerator()

        self.selection_behavior = 0

        self.legend = self.color_legend = None
        # Legend anchors; remove_legend stores the last position here and
        # create_legend reuses it for the new legend
        self.__legend_anchor = (1, 0), (1, 0)
        self.__color_legend_anchor = (1, 1), (1, 1)

        self.scale = None  # DiscretizedScale

        # Set of ids of the instances in the data subset, or None (new_data)
        self.subset_indices = None

        # self.setMouseTracking(True)
        # self.grabGesture(QPinchGesture)
        # self.grabGesture(QPanGesture)

        self.update_grid()

        # Route events of the scene through the delegate that calls
        # self.help_event
        self._tooltip_delegate = HelpEventDelegate(self.help_event)
        self.plot_widget.scene().installEventFilter(self._tooltip_delegate)

    def new_data(self, data, subset_data=None, **args):
        """Reset the plot state and set new data.

        The ids of the instances in *subset_data* (if it is given and not
        empty) are remembered in ``subset_indices``.
        """
        self.plot_widget.clear()

        # Forget everything that referred to the previous data
        self.density_img = None
        self.scatterplot_item = None
        self.scatterplot_item_sel = None
        self.labels = []
        self.selection = None
        self.valid_data = None

        if subset_data:
            self.subset_indices = {instance.id for instance in subset_data}
        else:
            self.subset_indices = None

        self.set_data(data, **args)

    def update_data(self, attr_x, attr_y, reset_view=True):
        """Rebuild the scatter plot for attributes *attr_x* and *attr_y*.

        Removes the previous plot items, labels and legend, then -- if there
        is any scaled data -- creates the scatter plot item, the item that
        outlines selected points and, when applicable, the density image.

        :param attr_x: name of the attribute shown on the bottom axis
        :param attr_y: name of the attribute shown on the left axis
        :param reset_view: if true, set the view range to the data range
        """
        self.shown_x = attr_x
        self.shown_y = attr_y

        self.remove_legend()
        if self.density_img:
            self.plot_widget.removeItem(self.density_img)
            self.density_img = None
        if self.scatterplot_item:
            self.plot_widget.removeItem(self.scatterplot_item)
            self.scatterplot_item = None
        if self.scatterplot_item_sel:
            self.plot_widget.removeItem(self.scatterplot_item_sel)
            self.scatterplot_item_sel = None
        for label in self.labels:
            self.plot_widget.removeItem(label)
        self.labels = []
        self.set_axis_title("bottom", "")
        self.set_axis_title("left", "")

        if self.scaled_data is None or not len(self.scaled_data):
            self.valid_data = None
            self.selection = None
            self.n_points = 0
            return

        index_x = self.attribute_name_index[attr_x]
        index_y = self.attribute_name_index[attr_y]
        self.valid_data = self.get_valid_list([index_x, index_y],
                                              also_class_if_exists=False)
        x_data, y_data = self.get_xy_data_positions(
            attr_x, attr_y, self.valid_data)
        self.n_points = len(x_data)

        # With no points there is no data range to show (and nanmin/nanmax
        # raise on empty arrays), so the current view is kept.
        if reset_view and self.n_points:
            min_x, max_x = np.nanmin(x_data), np.nanmax(x_data)
            min_y, max_y = np.nanmin(y_data), np.nanmax(y_data)
            self.view_box.setRange(
                QRectF(min_x, min_y, max_x - min_x, max_y - min_y),
                padding=0.025)
            self.view_box.init_history()
            self.view_box.tag_history()
        [min_x, max_x], [min_y, max_y] = self.view_box.viewRange()

        for axis, name, index in (("bottom", attr_x, index_x),
                                  ("left", attr_y, index_y)):
            self.set_axis_title(axis, name)
            var = self.data_domain[index]
            if var.is_discrete:
                self.set_labels(axis, get_variable_values_sorted(var))
            else:
                self.set_labels(axis, None)

        color_data, brush_data = self.compute_colors()
        color_data_sel, brush_data_sel = self.compute_colors_sel()
        size_data = self.compute_sizes()
        shape_data = self.compute_symbols()

        if self.should_draw_density():
            rgb_data = [pen.color().getRgb()[:3] for pen in color_data]
            self.density_img = classdensity.class_density_image(
                min_x, max_x, min_y, max_y, self.resolution,
                x_data, y_data, rgb_data)
            self.plot_widget.addItem(self.density_img)

        # Each point stores the index of its row among all rows as data
        data_indices = np.flatnonzero(self.valid_data)
        self.scatterplot_item = ScatterPlotItem(
            x=x_data, y=y_data, data=data_indices,
            symbol=shape_data, size=size_data, pen=color_data, brush=brush_data
        )
        # Slightly larger copies of the points that outline the selection;
        # added first, so they are drawn below the points themselves
        self.scatterplot_item_sel = ScatterPlotItem(
            x=x_data, y=y_data, data=data_indices,
            symbol=shape_data, size=size_data + SELECTION_WIDTH,
            pen=color_data_sel, brush=brush_data_sel
        )
        self.plot_widget.addItem(self.scatterplot_item_sel)
        self.plot_widget.addItem(self.scatterplot_item)

        self.scatterplot_item.selected_points = []
        self.scatterplot_item.sigClicked.connect(self.select_by_click)
        # The hook below used to be used by biolab.
        # Now it is only used by the ROI propagation, so
        # A better solution should be found.
        # TODO: Find a better solution to this hook.
        # The scene outlives the scatter plot item, so connecting on every
        # call would stack up connections and run mouseMoved several times
        # per event; drop the previous connection (if any) first.
        scene = self.scatterplot_item.scene()
        try:
            scene.sigMouseMoved.disconnect(self.mouseMoved)
        except (TypeError, RuntimeError):
            pass  # mouseMoved was not connected yet
        scene.sigMouseMoved.connect(self.mouseMoved)

        self.update_labels()
        self.make_legend()
        self.plot_widget.replot()

    def can_draw_density(self):
        """Tell whether the class density image can be drawn.

        This requires a discrete colour attribute and continuous attributes
        on both axes.
        """
        domain = self.data_domain
        if domain is None:
            return False

        color_name = self.attr_color
        discrete_color = False
        if color_name != "" and color_name != "(Same color)":
            discrete_color = domain[color_name].is_discrete

        continuous_x = continuous_y = False
        if self.shown_x and self.shown_y:
            continuous_x = domain[self.shown_x].is_continuous
            continuous_y = domain[self.shown_y].is_continuous
        return discrete_color and continuous_x and continuous_y

    def should_draw_density(self):
        """Tell whether the class density image is enabled and drawable."""
        return (self.class_density
                and self.n_points > 1
                and self.can_draw_density())

    def set_labels(self, axis, labels):
        """Show *labels* as ticks at positions 0, 1, ... on the given axis.

        An empty or missing *labels* resets the ticks by passing ``None``.
        """
        axis_item = self.plot_widget.getAxis(axis)
        if not labels:
            axis_item.setTicks(None)
            return
        axis_item.setTicks([list(enumerate(labels))])

    def set_axis_title(self, axis, title):
        """Set the text label of the given plot axis."""
        plot = self.plot_widget
        plot.setLabel(axis=axis, text=title)

    def get_size_index(self):
        """Return the index of the size attribute, or -1 if none is chosen."""
        name = self.attr_size
        if name == "" or name == "(Same size)":
            return -1
        return self.attribute_name_index[name]

    def compute_sizes(self):
        """Return an array with the size of every plotted point.

        Without a size attribute all points get ``point_width``. Otherwise
        the size is ``MinShapeSize`` plus the scaled attribute value times
        ``point_width``; unknown values get ``MinShapeSize - 2``.
        """
        size_index = self.get_size_index()
        if size_index != -1:
            scaled = self.no_jittering_scaled_data[size_index, self.valid_data]
            size_data = self.MinShapeSize + scaled * self.point_width
        else:
            size_data = np.full((self.n_points,), self.point_width)
        unknown = np.isnan(size_data)
        size_data[unknown] = self.MinShapeSize - 2
        return size_data

    def update_sizes(self):
        """Recompute point sizes and apply them to both scatter plot items.

        The selection outlines are larger by ``SELECTION_WIDTH``.
        """
        if not self.scatterplot_item:
            return
        sizes = self.compute_sizes()
        self.scatterplot_item.setSize(sizes)
        self.scatterplot_item_sel.setSize(sizes + SELECTION_WIDTH)

    update_point_size = update_sizes

    def get_color_index(self):
        """Return the index of the colour attribute, or -1 if none is chosen.

        For a discrete colour attribute the discrete palette is also resized
        to the number of its values.
        """
        name = self.attr_color
        if name == "" or name == "(Same color)":
            return -1
        index = self.attribute_name_index[name]
        variable = self.data_domain[name]
        if variable.is_discrete:
            self.discrete_palette.set_number_of_colors(len(variable.values))
        return index

    def compute_colors_sel(self, keep_colors=False):
        """Return pens and brushes for the item that outlines the selection.

        Selected points get a cosmetic pen of width ``SELECTION_WIDTH + 1``,
        all other points get no pen; all brushes are transparent.

        :param keep_colors: if false, ``pen_colors_sel`` and
            ``brush_colors_sel`` are reset to ``None``
        :return: a ``(pen, brush)`` pair of per-point lists
        """
        if not keep_colors:
            self.pen_colors_sel = self.brush_colors_sel = None

        highlight = QPen(QColor(255, 190, 0, 255), SELECTION_WIDTH + 1.)
        highlight.setCosmetic(True)
        # Index 0: pen of unselected points, index 1: pen of selected ones
        pens = [QPen(Qt.NoPen), highlight]
        if self.selection is not None:
            # int() because numpy booleans are not valid list indices
            pen = [pens[int(a)] for a in self.selection[self.valid_data]]
        else:
            pen = [pens[0]] * self.n_points
        brush = [QBrush(QColor(255, 255, 255, 0))] * self.n_points
        return pen, brush

    def compute_colors(self, keep_colors=False):
        """Return per-point pens and brushes as a ``(pen, brush)`` pair.

        Pens and base brush colours are cached in ``self.pen_colors`` and
        ``self.brush_colors``; they are recomputed unless *keep_colors* is
        true and the cache is filled. When a data subset is set, points
        outside of it get a fully transparent brush.

        :param keep_colors: reuse the cached colours
        """
        if not keep_colors:
            self.pen_colors = self.brush_colors = None
        color_index = self.get_color_index()

        def make_pen(color, width):
            p = QPen(color, width)
            p.setCosmetic(True)
            return p

        # Boolean mask over the plotted points: is the point in the subset?
        subset = None
        if self.subset_indices:
            subset = np.array([ex.id in self.subset_indices
                               for ex in self.raw_data[self.valid_data]])

        if color_index == -1:  # color = "Same color"
            color = self.plot_widget.palette().color(OWPalette.Data)
            pen = [make_pen(color, 1.5)] * self.n_points
            if subset is not None:
                # Index 0: outside the subset, index 1: inside it
                brushes = (
                    QBrush(QColor(128, 128, 128, 0)),
                    QBrush(QColor(128, 128, 128, self.alpha_value)))
                # int() because numpy booleans are not valid tuple indices
                brush = [brushes[int(s)] for s in subset]
            else:
                brush = [QBrush(QColor(128, 128, 128))] * self.n_points
            return pen, brush

        c_data = self.original_data[color_index, self.valid_data]
        if self.data_domain[color_index].is_continuous:
            if self.pen_colors is None:
                self.scale = DiscretizedScale(
                    np.nanmin(c_data), np.nanmax(c_data))
                # Map every value to the centre of its bin, scaled to [0, 1]
                c_data -= self.scale.offset
                c_data /= self.scale.width
                c_data = np.floor(c_data) + 0.5
                c_data /= self.scale.bins
                c_data = np.clip(c_data, 0, 1)
                palette = self.continuous_palette
                self.pen_colors = palette.getRGB(c_data)
                # Brush colours are the palette colours plus an alpha column
                self.brush_colors = np.hstack(
                    [self.pen_colors,
                     np.full((self.n_points, 1), self.alpha_value)])
                # Pens are a darker shade of the brush colour
                self.pen_colors *= 100 / self.DarkerValue
                self.pen_colors = [make_pen(QColor(*col), 1.5)
                                   for col in self.pen_colors.tolist()]
            if subset is not None:
                self.brush_colors[:, 3] = 0
                self.brush_colors[subset, 3] = self.alpha_value
            else:
                self.brush_colors[:, 3] = self.alpha_value
            pen = self.pen_colors
            brush = np.array([QBrush(QColor(*col))
                              for col in self.brush_colors.tolist()])
        else:
            if self.pen_colors is None:
                palette = self.discrete_palette
                n_colors = palette.number_of_colors
                c_data = c_data.copy()
                # Unknown values use the extra (grey) colour at the end
                c_data[np.isnan(c_data)] = n_colors
                c_data = c_data.astype(int)
                colors = np.r_[palette.getRGB(np.arange(n_colors)),
                               [[128, 128, 128]]]
                pens = np.array(
                    [make_pen(QColor(*col).darker(self.DarkerValue), 1.5)
                     for col in colors])
                self.pen_colors = pens[c_data]
                # Column 0: brush outside the subset, column 1: inside it
                self.brush_colors = np.array([
                    [QBrush(QColor(0, 0, 0, 0)),
                     QBrush(QColor(col[0], col[1], col[2], self.alpha_value))]
                    for col in colors])
                self.brush_colors = self.brush_colors[c_data]
            if subset is not None:
                brush = np.where(
                    subset,
                    self.brush_colors[:, 1], self.brush_colors[:, 0])
            else:
                brush = self.brush_colors[:, 1]
            pen = self.pen_colors
        return pen, brush

    def update_colors(self, keep_colors=False):
        """Recompute pens and brushes and apply them to both plot items.

        Unless *keep_colors* is true, the legend is rebuilt as well, and the
        density image is either redrawn (by rebuilding the whole plot) or
        removed, depending on ``should_draw_density``.
        """
        if self.scatterplot_item:
            pen_data, brush_data = self.compute_colors(keep_colors)
            pen_data_sel, brush_data_sel = self.compute_colors_sel(keep_colors)
            self.scatterplot_item.setPen(pen_data, update=False, mask=None)
            self.scatterplot_item.setBrush(brush_data, mask=None)
            self.scatterplot_item_sel.setPen(pen_data_sel, update=False, mask=None)
            self.scatterplot_item_sel.setBrush(brush_data_sel, mask=None)
            if not keep_colors:
                self.make_legend()

                if self.should_draw_density():
                    self.update_data(self.shown_x, self.shown_y)
                elif self.density_img:
                    self.plot_widget.removeItem(self.density_img)
                    # Forget the removed image (as update_data does), so it
                    # is not treated as still shown and removed again later
                    self.density_img = None

    update_alpha_value = update_colors

    def create_labels(self):
        """Add an empty TextItem at the position of every plotted point."""
        plot = self.plot_widget
        for point_x, point_y in zip(*self.scatterplot_item.getData()):
            text_item = TextItem()
            plot.addItem(text_item)
            text_item.setPos(point_x, point_y)
            self.labels.append(text_item)

    def update_labels(self):
        """Set the label texts from the attribute named ``attr_label``.

        With no label attribute the existing labels are emptied. Labels are
        created on first use, one per plotted point.
        """
        if not self.attr_label:
            for label in self.labels:
                label.setText("")
            return
        if not self.labels:
            self.create_labels()
        label_column = self.raw_data.get_column_view(self.attr_label)[0]
        if self.valid_data is not None:
            # There is one label per *plotted* point, and the per-point
            # sizes, colours and symbols are all taken from the rows in
            # valid_data; filter the column the same way so each label gets
            # the value of its own row.
            label_column = label_column[self.valid_data]
        formatter = self.raw_data.domain[self.attr_label].str_val
        black = pg.mkColor(0, 0, 0)
        for label, value in zip(self.labels, label_column):
            label.setText(formatter(value), black)

    def get_shape_index(self):
        """
        Return the index of the attribute used for point shapes, or -1.

        -1 is returned when no shape attribute is set, when it is the
        "(Same shape)" placeholder, or when the attribute has more values
        than there are symbols in `CurveSymbols`.
        """
        name = self.attr_shape
        if not name or name == "(Same shape)":
            return -1
        if len(self.data_domain[name].values) > len(self.CurveSymbols):
            return -1
        return self.attribute_name_index[name]

    def compute_symbols(self):
        """
        Return an array with one symbol from `CurveSymbols` per point.

        Without a usable shape attribute every point gets the first symbol.
        Otherwise the attribute's values (row `shape_index` of
        `original_data`, restricted by `valid_data`) index into
        `CurveSymbols`; NaN values are mapped to the last symbol.
        """
        symbols = self.CurveSymbols
        shape_index = self.get_shape_index()
        if shape_index == -1:
            return symbols[np.zeros(self.n_points, dtype=int)]

        codes = self.original_data[shape_index, self.valid_data]
        codes[np.isnan(codes)] = len(symbols) - 1
        return symbols[codes.astype(int)]

    def update_shapes(self):
        """
        Apply freshly computed symbols to the scatter plot item (if there is
        one) and rebuild the legend in any case.
        """
        item = self.scatterplot_item
        if item:
            item.setSymbol(self.compute_symbols())
        self.make_legend()

    def update_grid(self):
        """Show or hide the grid on both axes according to `show_grid`."""
        visible = self.show_grid
        self.plot_widget.showGrid(x=visible, y=visible)

    def update_legend(self):
        """Apply the `show_legend` setting to the legend, if one exists."""
        legend = self.legend
        if not legend:
            return
        legend.setVisible(self.show_legend)

    def create_legend(self):
        """
        Create a new `LegendItem`, store it in `self.legend`, parent it to
        the plot's view box and anchor it at the stored legend anchor.
        """
        legend = self.legend = LegendItem()
        view_box = self.plot_widget.getViewBox()
        legend.setParentItem(view_box)
        legend.anchor(*self.__legend_anchor)

    def remove_legend(self):
        """
        Detach and forget both the legend and the color legend.

        Before a legend is dropped its current anchor position is stored
        (when `legend_anchor_pos` returns one) so that a legend created
        later can be anchored at the same place.
        """
        legend = self.legend
        if legend:
            position = legend_anchor_pos(legend)
            if position is not None:
                self.__legend_anchor = position
            legend.setParent(None)
            self.legend = None

        color_legend = self.color_legend
        if color_legend:
            position = legend_anchor_pos(color_legend)
            if position is not None:
                self.__color_legend_anchor = position
            color_legend.setParent(None)
            self.color_legend = None

    def make_legend(self):
        """
        Rebuild the legends from scratch: remove the old ones, add the
        color and shape entries, then apply the visibility setting.
        """
        self.remove_legend()
        for build in (self.make_color_legend, self.make_shape_legend):
            build()
        self.update_legend()

    def make_color_legend(self):
        """
        Add the legend entries describing point colors.

        Does nothing when no color attribute is in use. A discrete color
        variable gets one sample per value in the shared legend (created on
        demand); if the same attribute also determines the shape, each
        sample uses the value's symbol instead of a circle. Any other
        variable gets a separate color legend holding a palette sample.
        """
        color_index = self.get_color_index()
        if color_index == -1:
            return

        color_var = self.data_domain[color_index]
        use_shape = self.get_shape_index() == color_index

        if not color_var.is_discrete:
            color_legend = self.color_legend = LegendItem()
            color_legend.setParentItem(self.plot_widget.getViewBox())
            color_legend.anchor(*self.__color_legend_anchor)

            sample = PaletteItemSample(self.continuous_palette, self.scale)
            color_legend.addItem(sample, "")
            color_legend.setGeometry(sample.boundingRect())
            return

        if not self.legend:
            self.create_legend()
        palette = self.discrete_palette
        for i, value in enumerate(color_var.values):
            pen_color = QColor(*palette.getRGB(i))
            fill = pen_color.lighter(self.DarkerValue)
            symbol = self.CurveSymbols[i] if use_shape else "o"
            sample = ScatterPlotItem(
                pen=pen_color, brush=fill, size=10, symbol=symbol)
            self.legend.addItem(sample, escape(value))

    def make_shape_legend(self):
        """
        Add one legend sample per value of the shape attribute.

        Skipped when there is no usable shape attribute or when it is the
        same attribute as the one used for colors. The legend is created on
        demand.
        """
        shape_index = self.get_shape_index()
        if shape_index == -1:
            return
        if shape_index == self.get_color_index():
            return

        if not self.legend:
            self.create_legend()
        shape_var = self.data_domain[shape_index]
        fill = self.plot_widget.palette().color(OWPalette.Data)
        # The outline is derived from the color before its alpha is changed.
        outline = QPen(fill.darker(self.DarkerValue))
        fill.setAlpha(self.alpha_value)
        for i, value in enumerate(shape_var.values):
            sample = ScatterPlotItem(
                pen=outline, brush=fill, size=10, symbol=self.CurveSymbols[i])
            self.legend.addItem(sample, escape(value))

    def propagate_region_of_interest(self):
        """
        Checks whether the graph shows new data. E.g. whether the user
        selected new attributes or zoomed in. Communicate this to the widget
        so the widget can pull data for the visualization at hand.

        The shown region is a dict mapping the shown x and y attribute
        names to the `range` of the bottom and left axis. When it differs
        from the region remembered in `self.shown_data`, the new region is
        stored and passed to `self.master.set_region_of_interest`.

        TODO:
        - Add this function to the proper hooks. Currently it is only
          in mouseMoved.
        - Better support for caching the shown_data, this is a bit of a hack.
        """
        # Imported locally so the method does not depend on the module's
        # import list.
        import logging

        plot = self.plot_widget
        region = {
            self.shown_x: plot.getAxis('bottom').range,
            self.shown_y: plot.getAxis('left').range,
        }
        # `shown_data` is created lazily; before the first call it is
        # treated as an empty region.
        if region == getattr(self, 'shown_data', {}):
            return

        # Reported through logging rather than print: this runs from the
        # mouse-move handler and must not write to stdout.
        logging.getLogger(__name__).debug("New data shown: %s", region)
        self.shown_data = region
        self.master.set_region_of_interest(region)


    # noinspection PyPep8Naming
    def mouseMoved(self, pos):
        """
        Mouse-move hook: propagate the region of interest to the widget.

        The position `pos` is not used.
        """
        # NOTE: other mouse-move handling that biolab removed has been
        # removed here as well.
        self.propagate_region_of_interest()

    def zoom_button_clicked(self):
        """Put the plot's view box into its `RectMode` mouse mode."""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.RectMode)

    def pan_button_clicked(self):
        """Put the plot's view box into its `PanMode` mouse mode."""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.PanMode)

    def select_button_clicked(self):
        """Put the plot's view box into its `RectMode` mouse mode."""
        view_box = self.plot_widget.getViewBox()
        view_box.setMouseMode(view_box.RectMode)

    def reset_button_clicked(self):
        """
        Redraw the currently shown attributes with `reset_view=True`.

        Going through `update_data` also redraws the density image.
        """
        x_attr, y_attr = self.shown_x, self.shown_y
        self.update_data(x_attr, y_attr, reset_view=True)

    def select_by_click(self, _, points):
        """
        Select the clicked `points`; ignored while there is no scatter plot
        item. The first argument is unused.
        """
        if self.scatterplot_item is None:
            return
        self.select(points)

    def select_by_rectangle(self, value_rect):
        """
        Select every point of the scatter plot item whose position lies
        inside `value_rect`; ignored while there is no scatter plot item.
        """
        item = self.scatterplot_item
        if item is None:
            return
        inside = [candidate
                  for candidate in item.points()
                  if value_rect.contains(QPointF(candidate.pos()))]
        self.select(inside)

    def unselect_all(self):
        """
        Drop the selection, refresh the point colors (keeping the cached
        colors) and tell the master widget that the selection changed.
        """
        self.selection = None
        self.update_colors(keep_colors=True)
        self.master.selection_changed()

    def select(self, points):
        """
        Update the selection mask with the given scatter plot points.

        The keyboard modifiers held at the time of the call decide how the
        points are combined with the existing selection:

        - Alt removes the points from the selection,
        - Ctrl toggles them,
        - Shift adds them,
        - with no modifier the selection is replaced by the points.

        Does nothing when there is no data. Afterwards the colors are
        refreshed and the master widget is notified.

        `points` is an iterable of objects whose `data()` returns the index
        of the corresponding row.
        """
        if self.raw_data is None:
            return
        # noinspection PyArgumentList
        keys = QApplication.keyboardModifiers()
        # Flags are combined with `|`; adding them only works while they
        # behave like plain integers.
        any_modifier = \
            Qt.ShiftModifier | Qt.ControlModifier | Qt.AltModifier
        if self.selection is None or not keys & any_modifier:
            # Builtin `bool` instead of the `np.bool` alias, which newer
            # NumPy releases no longer provide.
            self.selection = np.zeros(len(self.raw_data), dtype=bool)
        indices = [p.data() for p in points]
        if keys & Qt.AltModifier:
            self.selection[indices] = False
        elif keys & Qt.ControlModifier:
            self.selection[indices] = ~self.selection[indices]
        else:  # Handle shift and no modifiers
            self.selection[indices] = True
        self.update_colors(keep_colors=True)
        self.master.selection_changed()

    def get_selection(self):
        """
        Return the indices of the selected rows as an integer array; the
        array is empty when there is no selection.
        """
        if self.selection is not None:
            return np.flatnonzero(self.selection)
        return np.array([], dtype=int)

    def set_palette(self, p):
        """Set `p` as the palette of the plot widget."""
        widget = self.plot_widget
        widget.setPalette(p)

    def save_to_file(self, size):
        """Placeholder: does nothing and ignores `size`."""
        return None

    def help_event(self, event):
        """
        Show a tooltip describing the points under the event's position.

        For every point the tooltip lists either all attributes (when
        `tooltip_shows_all` is set and there are fewer than 30 of them) or
        just the two shown attributes, followed by the class value if the
        domain has a class variable. Descriptions of several points are
        separated by a dashed line.

        Returns True when a tooltip was shown, and False when there is no
        scatter plot item or no point at that position.
        """
        item = self.scatterplot_item
        if item is None:
            return False

        points = item.pointsAt(item.mapFromScene(event.scenePos()))
        # Tested through len() rather than truthiness, as in the original.
        if not len(points):
            return False

        sections = []
        for point in points:
            index = point.data()
            parts = ["Attributes:\n"]
            if self.tooltip_shows_all and \
                    len(self.data_domain.attributes) < 30:
                parts.extend(
                    '   {} = {}\n'.format(attr.name,
                                          self.raw_data[index][attr])
                    for attr in self.data_domain.attributes)
            else:
                parts.append('   {} = {}\n   {} = {}\n'.format(
                    self.shown_x, self.raw_data[index][self.shown_x],
                    self.shown_y, self.raw_data[index][self.shown_y]))
                if self.tooltip_shows_all:
                    parts.append("   ... and {} others\n\n".format(
                        len(self.data_domain.attributes) - 2))
            if self.data_domain.class_var:
                parts.append('Class:\n   {} = {}\n'.format(
                    self.data_domain.class_var.name,
                    self.raw_data[index][self.raw_data.domain.class_var]))
            sections.append("".join(parts))

        plain = '------------------\n'.join(sections)
        markup = '<span style="white-space:pre">{}</span>'.format(
            escape(plain))
        QToolTip.showText(event.screenPos(), markup, widget=self.plot_widget)
        return True