コード例 #1
0
 def ClassPredictionErrorViz(self):
     """
     Plot the class prediction error as a stacked bar chart and save it as PDF.

     For every actual class one bar is drawn; its stacked segments count how
     many samples of that class were predicted as each class. The figure is
     written to ``<path_to_save>/ClassPredictionError_<name>.pdf``.

     Returns
     -------
     ax : matplotlib Axes
         The axes on which the chart was drawn.

     Raises
     ------
     YellowbrickValueError
         If the target type reported by ``_check_targets`` is neither
         ``"binary"`` nor ``"multiclass"``.
     """
     import os

     # Use the validated arrays returned by _check_targets, not the raw
     # attributes: the boolean masks below require array semantics and would
     # silently produce all-zero counts (or crash) if y_true were a list.
     y_type, y_true, y_pred = _check_targets(self.y_true, self.y_pred)
     if y_type not in ("binary", "multiclass"):
         raise YellowbrickValueError("{} is not supported".format(y_type))

     # Unique labels seen in either the actual or the predicted values
     indices = unique_labels(y_true, y_pred)

     # predictions_[i, j] = number of samples whose actual label is indices[i]
     # and whose predicted label is indices[j]
     predictions_ = np.array([
         [(y_pred[y_true == label_t] == label_p).sum() for label_p in indices]
         for label_t in indices
     ])

     fig, ax = plt.subplots(ncols=1, nrows=1)
     # Anchor the legend's center-left point just outside the right edge
     legend_kws = {"bbox_to_anchor": (1.04, 0.5), "loc": "center left"}
     bar_stack(
         predictions_,
         ax,
         labels=list(self.classes),
         ticks=self.classes,
         legend_kws=legend_kws,
     )

     ax.set_title("Class Prediction Error for {}".format(self.name))
     ax.set_xlabel("Actual Class")
     ax.set_ylabel("Number of Predicted Class")

     # The tallest stacked bar is the largest row total; leave 10% headroom
     cmax = predictions_.sum(axis=1).max()
     ax.set_ylim(0, cmax + cmax * 0.1)

     # Reserve the right 10% of the figure for the legend
     fig.tight_layout(rect=[0, 0, 0.90, 1])
     # NOTE(review): the figure is not closed here because the axes are
     # returned to the caller; repeated calls accumulate open figures.
     fig.savefig(os.path.join(
         self.path_to_save, "ClassPredictionError_" + self.name + ".pdf"))
     return ax
コード例 #2
0
ファイル: postag.py プロジェクト: chrinide/yellowbrick
    def draw(self, **kwargs):
        """
        Render the part-of-speech tag counts as a bar chart; called from fit.

        Each value of ``self.pos_tag_counts_`` (itself a mapping) contributes
        one row of counts. When ``self.frequency`` is set, the columns and
        ``self._pos_tags`` are first reordered by descending column total.
        With ``self.stack`` all rows are drawn through ``bar_stack``;
        otherwise only the first row is drawn with ``ax.bar``, one color per
        tag.

        Parameters
        ----------
        kwargs: dict
            generic keyword arguments (not referenced by this method).

        Returns
        -------
        ax : matplotlib axes
            Axes on which the PosTagVisualizer was drawn.
        """
        # Flatten the nested mapping into a 2-D array: one row per outer
        # entry, one column per inner value
        rows = [
            list(tag_counts.values())
            for tag_counts in self.pos_tag_counts_.values()
        ]
        counts = np.array(rows)

        if self.frequency:
            # Highest column total first
            order = np.sum(counts, axis=0).argsort()[::-1]
            # NOTE(review): self._pos_tags is overwritten with the permuted
            # order, while pos_tag_counts_ (re-read on every call) is not;
            # a second draw() with frequency enabled looks like it would
            # misalign tags and bars — confirm against fit()/finalize().
            self._pos_tags = np.array(self._pos_tags)[order]
            counts = counts[:, order]

        if not self.stack:
            positions = np.arange(len(self._pos_tags))
            bar_colors = resolve_colors(
                n_colors=len(self._pos_tags),
                colormap=self.colormap,
                colors=self.colors,
            )
            self.ax.bar(positions, counts[0], color=bar_colors)
            return self.ax

        bar_stack(
            counts,
            ax=self.ax,
            labels=list(self.labels_),
            ticks=self._pos_tags,
            colors=self.colors,
            colormap=self.colormap,
        )
        return self.ax
コード例 #3
0
    def draw(self, **kwargs):
        """
        Draw the feature importances as a horizontal bar chart; called from fit.

        Raises ``NotFitted`` if ``feature_importances_`` or ``features_`` is
        missing (checked in that order). With ``self.stack`` the importances
        go through ``bar_stack`` horizontally, labelled by ``classes_``;
        otherwise one horizontal bar per feature is drawn and the y ticks are
        labelled with the feature names.

        Returns
        -------
        ax : matplotlib Axes
            The axes on which the chart was drawn.
        """
        # Report the first missing fitted attribute, if any
        missing = next(
            (
                attr
                for attr in ("feature_importances_", "features_")
                if not hasattr(self, attr)
            ),
            None,
        )
        if missing is not None:
            raise NotFitted("missing required param '{}'".format(missing))

        # Bar centers offset by half a unit (computed in both modes, so
        # features_ must expose .shape either way)
        bar_positions = 0.5 + np.arange(self.features_.shape[0])

        if not self.stack:
            bar_colors = resolve_colors(
                len(self.features_), colormap=self.colormap, colors=self.colors
            )
            self.ax.barh(
                bar_positions,
                self.feature_importances_,
                color=bar_colors,
                align="center",
            )
            # Label each bar with its feature name
            self.ax.set_yticks(bar_positions)
            self.ax.set_yticklabels(self.features_)
            return self.ax

        # Stacked mode takes per-class colors from the colormap only;
        # self.colors is not consulted here.
        class_colors = resolve_colors(len(self.classes_), colormap=self.colormap)
        bar_stack(
            self.feature_importances_,
            ax=self.ax,
            labels=list(self.classes_),
            ticks=self.features_,
            orientation="h",
            colors=class_colors,
            legend_kws={"bbox_to_anchor": (1.04, 0.5), "loc": "center left"},
        )
        return self.ax
コード例 #4
0
    def draw(self):
        """
        Render the class prediction error as stacked bars on ``self.ax``.

        Raises the error produced by ``NotFitted.from_estimator`` when
        ``predictions_`` or ``classes_`` has not been set.

        Returns
        -------
        ax : Matplotlib Axes
            The axes on which the figure is plotted
        """
        is_fitted = all(
            hasattr(self, attr) for attr in ("predictions_", "classes_")
        )
        if not is_fitted:
            raise NotFitted.from_estimator(self, "draw")

        # Anchor the legend's center-left point at (1.04, 0.5) in axes terms
        legend_options = {"bbox_to_anchor": (1.04, 0.5), "loc": "center left"}
        bar_stack(
            self.predictions_,
            self.ax,
            labels=list(self.classes_),
            ticks=self.classes_,
            colors=self.class_colors_,
            legend_kws=legend_options,
        )
        return self.ax