示例#1
0
    def test_score(self):
        """
        score() returns the R2 value and the train/test scores are recorded
        """
        oz = ResidualsPlot(Ridge(random_state=8893))
        oz.fit(self.data.X.train, self.data.y.train)

        r2 = oz.score(self.data.X.test, self.data.y.test)

        # The value handed back by score() must match the stored test score
        assert oz.test_score_ == r2

        # Both scores are compared against known values with a relative tolerance
        assert r2 == pytest.approx(0.9999888484, rel=1e-4)
        assert oz.train_score_ == pytest.approx(0.9999906, rel=1e-4)
示例#2
0
 def test_either_hist_or_QQ_plot(self):
     """
     Requesting hist=True together with qqplot=True is rejected.
     """
     # Pattern the raised error message is expected to match
     expected_msg = "Set either hist or qqplot to False"

     with pytest.raises(YellowbrickValueError, match=expected_msg):
         ResidualsPlot(LinearRegression(), hist=True, qqplot=True)
示例#3
0
    def test_residuals_plot(self):
        """
        Image similarity check of the residuals plot using an OLS model
        """
        axes = plt.subplots()[1]
        oz = ResidualsPlot(LinearRegression(), ax=axes)

        # Fit on the training split, score on the test split, then finalize
        oz.fit(self.data.X.train, self.data.y.train)
        oz.score(self.data.X.test, self.data.y.test)
        oz.finalize()

        self.assert_images_similar(oz)
示例#4
0
    def test_hist_matplotlib_version(self, mock_toolkit):
        """
        YellowbrickValueError is raised when hist=True and make_axes_locatable
        cannot be imported
        """
        # Precondition: the import has to fail here
        # (presumably arranged by mock_toolkit -- confirm against its definition)
        with pytest.raises(ImportError):
            from mpl_toolkits.axes_grid1 import make_axes_locatable

            assert not make_axes_locatable

        expected_msg = "requires matplotlib 2.0.2"
        with pytest.raises(YellowbrickValueError, match=expected_msg):
            ResidualsPlot(LinearRegression(), hist=True)
示例#5
0
    def test_residuals_plot_no_histogram(self):
        """
        Image similarity check of the residuals plot drawn with hist=False
        """
        axes = plt.subplots()[1]
        oz = ResidualsPlot(MLPRegressor(random_state=19), ax=axes, hist=False)

        # Fit on the training split, score on the test split, then finalize
        oz.fit(self.data.X.train, self.data.y.train)
        oz.score(self.data.X.test, self.data.y.test)
        oz.finalize()

        self.assert_images_similar(oz, tol=1, remove_legend=True)
示例#6
0
    def test_no_hist_matplotlib_version(self, mock_toolkit):
        """
        With hist=False, construction succeeds even though
        make_axes_locatable cannot be imported
        """
        # Precondition: the import has to fail here
        # (presumably arranged by mock_toolkit -- confirm against its definition)
        with pytest.raises(ImportError):
            from mpl_toolkits.axes_grid1 import make_axes_locatable

            assert not make_axes_locatable

        # A YellowbrickValueError during construction is reported as a failure
        try:
            ResidualsPlot(LinearRegression(), hist=False)
        except YellowbrickValueError as err:
            self.fail(err)
示例#7
0
    def test_residuals_plot_numpy(self):
        """
        Image similarity check using Lasso on the energy dataset as NumPy arrays
        """
        axes = plt.subplots()[1]

        # Load the energy dataset and convert it to NumPy
        X, y = load_energy(return_dataset=True).to_numpy()

        # Hold out 20% of the rows for scoring
        X_train, X_test, y_train, y_test = tts(
            X, y, test_size=0.2, random_state=231
        )

        oz = ResidualsPlot(Lasso(random_state=44), ax=axes)
        oz.fit(X_train, y_train)
        oz.score(X_test, y_test)
        oz.finalize()

        self.assert_images_similar(oz, tol=1.5)
示例#8
0
    def test_alpha_param(self, mock_sca):
        """
        Alpha values given at instantiation are stored and passed to scatter
        """
        expected_alphas = {"train_point": 0.3, "test_point": 0.75}

        # Build a residuals plot with custom train/test alphas
        oz = ResidualsPlot(
            Ridge(random_state=8893),
            train_alpha=0.3,
            test_alpha=0.75,
            hist=False,
        )

        # The constructor arguments end up in the alphas mapping
        assert oz.alphas == expected_alphas

        # Swap the axes for a mock so the scatter call can be inspected
        oz.ax = mock.MagicMock()
        oz.fit(self.data.X.train, self.data.y.train)
        oz.score(self.data.X.test, self.data.y.test)

        # call_args holds the most recent scatter call; index 1 is its kwargs
        scatter_kwargs = oz.ax.scatter.call_args[1]
        assert "alpha" in scatter_kwargs
        assert scatter_kwargs["alpha"] == 0.75
示例#9
0
    def test_residuals_with_fitted(self):
        """
        ResidualsPlot does not refit an already-fitted model unless
        is_fitted=False is given
        """
        X, y = load_energy(return_dataset=True).to_numpy()
        model = Ridge().fit(X, y)

        # (constructor kwargs, whether the wrapped model's fit should be called)
        scenarios = (
            ({}, False),
            ({"is_fitted": True}, False),
            ({"is_fitted": False}, True),
        )

        for ctor_kwargs, expect_fit in scenarios:
            with mock.patch.object(model, "fit") as mockfit:
                oz = ResidualsPlot(model, **ctor_kwargs)
                oz.fit(X, y)

                if expect_fit:
                    mockfit.assert_called_once_with(X, y)
                else:
                    mockfit.assert_not_called()