예제 #1
0
 def get_target_color_type(self):
     """
     Return the target color type stored on the visualizer, i.e. the value
     computed during fit or specified by the user.

     Raises
     ------
     NotFitted
         If no target color type has been set (the stored value is None).
     """
     color_type = self._target_color_type
     if color_type is not None:
         return color_type

     # Nothing has been computed or supplied yet.
     raise NotFitted("unknown target color type on unfitted visualizer")
예제 #2
0
    def draw(self, X, y=None):
        """
        Scatter the first two columns of X onto the visualizer's axes, with
        point colors chosen according to the stored target color type. May be
        called repeatedly before finalize to layer additional scatter plots,
        but ``fit()`` must have been called first.

        Parameters
        ----------
        X : array-like of shape (n, 2)
            The matrix produced by the ``transform()`` method.

        y : array-like of shape (n,), optional
            The target, used to specify the colors of the points. Required
            when the target color type is discrete or continuous.

        Returns
        -------
        self.ax : matplotlib Axes object
            The axes that the scatter plot was drawn on.

        Raises
        ------
        YellowbrickValueError
            If y is None for a discrete or continuous target.

        NotFitted
            If the target color type is not one of the known types.
        """
        point_kwargs = {"alpha": self.alpha}
        target_type = self._target_color_type

        if target_type == SINGLE:
            # A single target type uses one fixed color for every point.
            point_kwargs["c"] = "b"

        elif target_type == DISCRETE:
            if y is None:
                raise YellowbrickValueError(
                    "y is required for discrete target")

            # Look up each label's position in classes_ and use that position
            # to pick the matching entry from the color list.
            per_point_colors = []
            for label in y:
                class_index = np.searchsorted(self.classes_, (label))
                per_point_colors.append(self._colors[class_index])
            point_kwargs["c"] = per_point_colors

        elif target_type == CONTINUOUS:
            if y is None:
                raise YellowbrickValueError(
                    "y is required for continuous target")

            # TODO manually make colorbar so we can draw it in finalize
            point_kwargs.update(
                c=y,
                cmap=self.colors or palettes.DEFAULT_SEQUENCE,
            )

        else:
            # Technically this should never be raised
            raise NotFitted("could not determine target color type")

        # Keep a handle on the scatter artist and hand back the axes.
        xs, ys = X[:, 0], X[:, 1]
        self._scatter = self.ax.scatter(xs, ys, **point_kwargs)
        return self.ax
예제 #3
0
    def score(self, X, y):
        """
        Score the wrapped estimator on the given data, store the result on
        ``score_`` and return it. If the visualizer itself was never fit,
        the class labels are recovered from the (already fitted) estimator
        first.

        Parameters
        ----------
        X : array-like
            X (also X_test) is the feature data of the test set to predict on.

        y : array-like
            y (also y_test) is the actual target values to score against.

        Returns
        -------
        score : float
            The value returned by the underlying estimator's ``score()``
            method, usually accuracy for classification models. Refer to the
            specific model for more details.

        Raises
        ------
        NotFitted
            If the visualizer has no ``class_counts_`` and the estimator has
            no ``classes_`` attribute.
        """
        # NOTE: hasattr(self, "classes_") cannot be used as the fitted check
        # because it will be proxied; class_counts_ is checked instead.
        if not hasattr(self, "class_counts_"):
            fitted_estimator = self.estimator

            if hasattr(fitted_estimator, "classes_"):
                # The classes can be recovered from the estimator, but the
                # class counts cannot, so they are set to None and the user
                # is warned.
                self.class_counts_ = None
                self.classes_ = self._decode_labels(fitted_estimator.classes_)
                warnings.warn(
                    "could not determine class_counts_ from previously fitted classifier",
                    YellowbrickWarning,
                )
            else:
                message = (
                    "could not determine required property classes_; "
                    "the visualizer must either be fit or instantiated with a "
                    "fitted classifier before calling score()"
                )
                raise NotFitted(message)

        # This method implements ScoreVisualizer (do not call super).
        self.score_ = self.estimator.score(X, y)
        return self.score_
예제 #4
0
    def _determine_scatter_kwargs(self, y=None):
        """
        Build the keyword arguments to pass into ``plt.scatter()``. For a
        single or discrete target only the colors are determined. For a
        continuous target the colors and colormap are determined and a
        normalizer over ``range_`` is stored on ``_norm``.

        Parameters
        ----------
        y : array-like of shape (n,), optional
            The target, used to specify the colors of the points. Required
            for discrete and continuous targets.

        Returns
        -------
        scatter_kwargs : dict
            Always contains ``alpha`` and ``c``; also contains ``cmap`` for a
            continuous target.

        Raises
        ------
        YellowbrickValueError
            If y is None for a discrete or continuous target, or if looking
            up a discrete label raises an IndexError.

        NotFitted
            If the target color type is not one of the known types.
        """
        kwargs = {"alpha": self.alpha}
        color_type = self._target_color_type

        if color_type == TargetType.SINGLE:
            kwargs["c"] = self._colors
            return kwargs

        if color_type == TargetType.DISCRETE:
            if y is None:
                raise YellowbrickValueError(
                    "y is required for discrete target")

            try:
                # Each label indexes classes_, whose entry then selects the
                # color for that point.
                point_colors = []
                for yi in y:
                    point_colors.append(self._colors[self.classes_[yi]])
            except IndexError:
                raise YellowbrickValueError(
                    "Target needs to be label encoded.")

            kwargs["c"] = point_colors
            return kwargs

        if color_type == TargetType.CONTINUOUS:
            if y is None:
                raise YellowbrickValueError(
                    "y is required for continuous target")

            kwargs["c"] = y
            kwargs["cmap"] = self._colors
            # Remember the normalizer spanning the stored target range.
            lower, upper = self.range_[0], self.range_[1]
            self._norm = mpl.colors.Normalize(vmin=lower, vmax=upper)
            return kwargs

        # Technically this should never be raised
        raise NotFitted("could not determine target color type")
예제 #5
0
    def get_colors(self, y):
        """
        Look up the color(s) for the given value(s) of y using the stored
        ``_colors`` property, for whichever target color type is set.

        Parameters
        ----------
        y : array-like
            The values of y to get the associated colors for.

        Returns
        -------
        colors : list
            One color per value in y. For a continuous target this is
            whatever calling ``_colors`` on the normalized values returns.

        Raises
        ------
        NotFitted
            If ``_colors`` is None.

        YellowbrickKeyError
            If a discrete value cannot be resolved to a color.

        YellowbrickValueError
            If the target color type is not one of the known types.
        """
        if self._colors is None:
            raise NotFitted("cannot determine colors on unfitted visualizer")

        color_type = self._target_color_type

        if color_type == TargetType.SINGLE:
            # Every value shares the one stored color.
            shared_color = self._colors
            return [shared_color for _ in range(len(y))]

        if color_type == TargetType.DISCRETE:
            try:
                # Map each value through the label encoder (falling back to
                # the value itself when it is not a key) and use the result
                # to index the stored colors.
                resolved = []
                for yi in y:
                    class_name = self._label_encoder.get(yi, yi)
                    resolved.append(self._colors[class_name])
                return resolved
            except KeyError:
                missing = set(y) - set(self._label_encoder.keys())
                missing = ", ".join("'{}'".format(uk) for uk in missing)
                raise YellowbrickKeyError(
                    "could not determine color for classes {}".format(missing))

        if color_type == TargetType.CONTINUOUS:
            # Normalize values into target range and compute colors from colormap
            scale = Normalize(*self.range_)
            scaled = scale(y)
            return self._colors(scaled)

        # This is a developer error, we should never get here!
        raise YellowbrickValueError("unknown target color type '{}'".format(
            self._target_color_type))
예제 #6
0
    def draw(self, **kwargs):
        """
        Render the feature importances as a horizontal bar chart on the
        visualizer's axes, labelling each bar with its feature name; intended
        to be called from fit. The keyword arguments are accepted but unused.

        Returns
        -------
        self.ax : matplotlib Axes object
            The axes the bars were drawn on.

        Raises
        ------
        NotFitted
            If ``feature_importances_`` or ``features_`` is missing.
        """
        # Quick validation: report the first required attribute that is absent.
        required = ("feature_importances_", "features_")
        absent = next(
            (name for name in required if not hasattr(self, name)), None
        )
        if absent is not None:
            raise NotFitted("missing required param '{}'".format(absent))

        # Center bar i at i + 0.5.
        n_features = self.features_.shape[0]
        bar_positions = 0.5 + np.arange(n_features)

        axes = self.ax
        axes.barh(bar_positions, self.feature_importances_, align="center")

        # Put a tick on every bar and label it with the feature name.
        axes.set_yticks(bar_positions)
        axes.set_yticklabels(self.features_)

        return self.ax
예제 #7
0
    def class_colors_(self):
        """
        Return ``_colors`` when it is already set; otherwise compute one
        categorical color per entry in ``classes_``, store the result on
        ``_colors`` and return it. Raises NotFitted when neither ``_colors``
        nor ``classes_`` is available.

        Subclasses that require users to choose colors, or that have
        specialized color handling, should set ``_colors`` on init or during
        fit.

        Notes
        -----
        Because this is a property, this docstring is for developers only.
        """
        # Colors already chosen or computed on an earlier access.
        if hasattr(self, "_colors"):
            return self._colors

        if not hasattr(self, "classes_"):
            raise NotFitted("cannot determine colors before fit")

        # TODO: replace with resolve_colors
        n_classes = len(self.classes_)
        self._colors = color_palette(None, n_classes)
        return self._colors
예제 #8
0
    def draw(self, **kwargs):
        """
        Render the feature importances as a horizontal bar chart; intended to
        be called from fit. When ``stack`` is truthy the importances are
        drawn through ``bar_stack`` with one color per class and a legend;
        otherwise a plain bar chart is drawn with one color per feature. The
        keyword arguments are accepted but unused.

        Returns
        -------
        self.ax : matplotlib Axes object
            The axes the bars were drawn on.

        Raises
        ------
        NotFitted
            If ``feature_importances_`` or ``features_`` is missing.
        """
        # Quick validation: report the first required attribute that is absent.
        required = ("feature_importances_", "features_")
        absent = next(
            (name for name in required if not hasattr(self, name)), None
        )
        if absent is not None:
            raise NotFitted("missing required param '{}'".format(absent))

        # Center bar i at i + 0.5.
        bar_positions = 0.5 + np.arange(self.features_.shape[0])

        if not self.stack:
            bar_colors = resolve_colors(
                len(self.features_),
                colormap=self.colormap,
                colors=self.colors,
            )
            self.ax.barh(
                bar_positions,
                self.feature_importances_,
                color=bar_colors,
                align="center",
            )

            # Put a tick on every bar and label it with the feature name.
            self.ax.set_yticks(bar_positions)
            self.ax.set_yticklabels(self.features_)
            return self.ax

        # Stacked mode: one color per class, with the legend anchored at
        # (1.04, 0.5) using the "center left" location.
        class_colors = resolve_colors(len(self.classes_), colormap=self.colormap)
        bar_stack(
            self.feature_importances_,
            ax=self.ax,
            labels=list(self.classes_),
            ticks=self.features_,
            orientation="h",
            colors=class_colors,
            legend_kws={"bbox_to_anchor": (1.04, 0.5), "loc": "center left"},
        )
        return self.ax