def test_release(self, mocker):
    """Check that ``release()`` invalidates the filesystem cache once.

    ``fsspec.filesystem`` is patched so the writer receives a mock
    filesystem; the test asserts that ``invalidate_cache`` is called exactly
    once with the ``BUCKET_NAME/KEY_PATH`` string.
    """
    patched_fs = mocker.patch("fsspec.filesystem").return_value
    writer = MatplotlibWriter(filepath=FULL_PATH)
    writer.release()
    expected_path = f"{BUCKET_NAME}/{KEY_PATH}"
    patched_fs.invalidate_cache.assert_called_once_with(expected_path)
Example #2
0
def evaluate_LightGBM_model(regressor: lgb.basic.Booster, X_test: pd.DataFrame,
                            parameters: Dict) -> pd.DataFrame:
    """Predict with a LightGBM booster and, in training mode, report ROC/AUC.

    The whole ``X_test`` frame is passed to ``regressor.predict`` using the
    booster's ``best_iteration``. When ``parameters['isTrain']`` is truthy,
    the ROC curve of the predictions against ``X_test[parameters['target']]``
    is plotted, saved through ``MatplotlibWriter`` under
    ``data/07_model_output`` (today's date is part of the file name), and the
    AUC is logged.

    Args:
        regressor: Booster used for prediction.
        X_test: Frame containing the ID column and, in training mode, the
            target column.
        parameters: Mapping read for the keys ``'isTrain'``, ``'id_name'``
            and (in training mode) ``'target'``.

    Returns:
        A frame with the columns ``ID`` and ``y_pred``.
    """
    train_mode = parameters['isTrain']
    ids = X_test[parameters['id_name']]
    predictions = regressor.predict(X_test,
                                    num_iteration=regressor.best_iteration)
    print('y predicted on LightGBM!')

    if train_mode:
        truth = X_test[parameters['target']]
        print(type(predictions))

        # ROC curve of the model drawn against the chance diagonal.
        false_pos, true_pos, _ = roc_curve(truth, predictions)
        plt.plot([0, 1], [0, 1], linestyle='--', label='No Skill')
        plt.plot(false_pos, true_pos, marker='.', label='LGBM')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc='lower right')

        auc = roc_auc_score(truth, predictions)

        # Persist the current pyplot figure, then clear it so later plots
        # start from an empty canvas.
        today = datetime.date.today()
        plot_path = f'data/07_model_output/ROC_plot_LGBM{today}.png'
        MatplotlibWriter(filepath=plot_path).save(plt)
        plt.clf()

        logging.getLogger(__name__).info('LightGBM AUC is %.3f.', auc)

    return pd.DataFrame({'ID': ids, 'y_pred': predictions})
Example #3
0
def plot_multiple_line_box(pl_perfs, al_perfs, bs, budget, n_init):
    """Build one line/box plot per entry of ``bs`` and save them together.

    For every ``b`` in ``bs`` the result of ``plot_line_box`` is stored under
    the key ``b`` and the current pyplot figure is closed. The resulting dict
    is saved through ``MatplotlibWriter`` at
    ``data/08_reporting/multi_line_box``.

    ``n_init`` is accepted but not used by this function. Nothing is
    returned.
    """
    figures_by_b = {}
    for b_value in bs:
        figures_by_b[b_value] = plot_line_box(pl_perfs, al_perfs, bs, budget,
                                              b_value)
        # Close the current figure before the next one is created.
        plt.close()

    writer = MatplotlibWriter(filepath="data/08_reporting/multi_line_box")
    writer.save(figures_by_b)
 def test_versioning_existing_dataset_dict_plot(self, plot_writer,
                                                versioned_plot_writer,
                                                mock_dict_plot):
     """Save a versioned dict of plots over an existing non-versioned one.

     A non-versioned writer is first pointed at the versioned writer's
     filepath and used to save the dict of plots. Because a dict of plots is
     saved to a directory, the subsequent versioned save is expected to
     succeed rather than raise.
     """
     shared_path = versioned_plot_writer._filepath.as_posix()
     unversioned_writer = MatplotlibWriter(filepath=shared_path)

     unversioned_writer.save(mock_dict_plot)
     assert unversioned_writer.exists()

     versioned_plot_writer.save(mock_dict_plot)
     assert versioned_plot_writer.exists()
    def test_version_str_repr(self, load_version, save_version):
        """Check that ``str()`` of a writer mentions the version only when set.

        The non-versioned writer's string must contain the filepath and no
        ``"version"``; the versioned writer's string must contain both the
        filepath and the ``Version(load=..., save='...')`` fragment.
        """
        filepath = "chart.png"
        plain_writer = MatplotlibWriter(filepath=filepath)
        versioned_writer = MatplotlibWriter(
            filepath=filepath,
            version=Version(load_version, save_version),
        )

        plain_text = str(plain_writer)
        assert filepath in plain_text
        assert "version" not in plain_text

        versioned_text = str(versioned_writer)
        expected_fragment = (
            f"version=Version(load={load_version}, save='{save_version}')"
        )
        assert filepath in versioned_text
        assert expected_fragment in versioned_text
def test_bad_credentials(mock_dict_plot):
    """Check that saving with invalid AWS credentials raises ``DataSetError``.

    The error message is expected to match the "Access Key Id ... does not
    exist" text.
    """
    invalid_credentials = {
        "client_kwargs": {
            "aws_access_key_id": "not_for_testing",
            "aws_secret_access_key": "definitely_not_for_testing",
        }
    }
    writer = MatplotlibWriter(filepath=FULL_PATH,
                              credentials=invalid_credentials)

    expected_message = (
        r"The AWS Access Key Id you provided does not exist in our records"
    )
    with pytest.raises(DataSetError, match=expected_message):
        writer.save(mock_dict_plot)
def plot_writer(mocked_s3_bucket, fs_args, save_args):  # pylint: disable=unused-argument
    """Return a ``MatplotlibWriter`` for ``FULL_PATH`` with the given args.

    ``mocked_s3_bucket`` is accepted but not used in the body.
    """
    writer_kwargs = {
        "filepath": FULL_PATH,
        "credentials": AWS_CREDENTIALS,
        "fs_args": fs_args,
        "save_args": save_args,
    }
    return MatplotlibWriter(**writer_kwargs)
    def test_http_filesystem_no_versioning(self):
        """Check that a versioned writer on an HTTP(S) path is rejected.

        Constructing the writer with an ``https://`` filepath and a
        ``Version`` must raise ``DataSetError`` with the "doesn't support
        versioning" message.
        """
        expected_message = r"HTTP\(s\) DataSet doesn't support versioning\."
        remote_path = "https://example.com/file.png"

        with pytest.raises(DataSetError, match=expected_message):
            MatplotlibWriter(filepath=remote_path, version=Version(None, None))
    def test_fs_args(self, tmp_path, mock_single_plot, mocked_encrypted_s3_bucket):
        """Test writing to encrypted bucket.

        The plot is saved through a writer configured with
        ``ServerSideEncryption: AES256``. The object is then downloaded from
        the bucket and compared byte-for-byte with the same plot saved
        locally via ``savefig``.
        """
        encryption_args = {
            "s3_additional_kwargs": {"ServerSideEncryption": "AES256"}
        }
        encrypted_writer = MatplotlibWriter(
            fs_args=encryption_args,
            filepath=FULL_PATH,
            credentials=AWS_CREDENTIALS,
        )
        encrypted_writer.save(mock_single_plot)

        downloaded = tmp_path / "downloaded_image.png"
        reference = tmp_path / "locally_saved.png"

        # Reference image written directly by matplotlib.
        mock_single_plot.savefig(str(reference))
        mocked_encrypted_s3_bucket.download_file(
            BUCKET_NAME, KEY_PATH, str(downloaded)
        )

        assert reference.read_bytes() == downloaded.read_bytes()
    def test_versioning_existing_dataset_single_plot(self, plot_writer,
                                                     versioned_plot_writer,
                                                     mock_single_plot):
        """Save a versioned single plot over an existing non-versioned one.

        With a non-versioned file already at the versioned writer's filepath,
        the versioned save must raise ``DataSetError``. The message must
        mention both the name clash and the parent directory. Once the
        non-versioned file is removed, the versioned save must succeed.
        """
        unversioned_writer = MatplotlibWriter(
            filepath=versioned_plot_writer._filepath.as_posix())
        unversioned_writer.save(mock_single_plot)
        assert unversioned_writer.exists()

        parent_dir = versioned_plot_writer._filepath.parent.as_posix()
        # Two lookaheads so both fragments may appear in any order.
        expected_message = (
            "(?=.*file with the same name already exists in the directory)"
            f"(?=.*{parent_dir})"
        )
        with pytest.raises(DataSetError, match=expected_message):
            versioned_plot_writer.save(mock_single_plot)

        # Remove non-versioned dataset and try again
        existing_file = Path(unversioned_writer._filepath.as_posix())
        existing_file.unlink()
        versioned_plot_writer.save(mock_single_plot)
        assert versioned_plot_writer.exists()
 def test_ineffective_overwrite(self, load_version, save_version):
     """Check that ``overwrite=True`` is overridden when versioning is on.

     Constructing a versioned writer with ``overwrite=True`` must emit a
     ``UserWarning`` and leave the writer's ``_overwrite`` flag falsy.
     """
     expected_warning = (
         "Setting `overwrite=True` is ineffective if versioning is enabled, "
         "since the versioned path must not already exist; "
         "overriding flag with `overwrite=False` instead."
     )
     version = Version(load_version, save_version)

     with pytest.warns(UserWarning, match=expected_warning):
         writer = MatplotlibWriter(filepath="/tmp/file.txt",
                                   version=version,
                                   overwrite=True)

     assert not writer._overwrite
Example #12
0
def evaluate_XGBoost_model(regressor: xgb.core.Booster, X_test: pd.DataFrame,
                           parameters: Dict) -> pd.DataFrame:
    """Predict with an XGBoost booster and, in training mode, report ROC/AUC.

    Only the columns named in ``regressor.feature_names`` are fed to the
    model. When ``parameters['isTrain']`` is truthy, the ROC curve of the
    predictions against ``X_test[parameters['target']]`` is plotted, saved
    through ``MatplotlibWriter`` under ``data/07_model_output`` (today's date
    is part of the file name), and the AUC is logged.

    Args:
        regressor: Booster used for prediction; its ``feature_names`` select
            the model inputs from ``X_test``.
        X_test: Frame containing the feature columns, the ID column and, in
            training mode, the target column.
        parameters: Mapping read for the keys ``'target'``, ``'id_name'`` and
            ``'isTrain'``.

    Returns:
        A frame with the columns ``ID`` (the values of the ID column of
        ``X_test``) and ``y_pred``.

    Raises:
        KeyError: If a required key is missing from ``parameters`` or a
            required column is missing from ``X_test``.
    """
    target_name = parameters['target']
    # Use the ID *values* from the frame. Passing the bare column name would
    # make pandas broadcast that one string into every row of the output.
    output_id = X_test[parameters['id_name']]
    use_features = regressor.feature_names
    is_train = parameters['isTrain']

    xgb_test = xgb.DMatrix(X_test[use_features], feature_names=use_features)
    # NOTE(review): `ntree_limit` is reported as deprecated in newer XGBoost
    # releases in favour of `iteration_range` - confirm against the pinned
    # XGBoost version before changing this call.
    y_pred = regressor.predict(xgb_test,
                               ntree_limit=regressor.best_ntree_limit)
    print('y predicted on XGBoost!')

    if is_train:
        y_test = X_test[target_name]
        print(type(y_pred))

        # ROC curve of the model drawn against the chance diagonal.
        fpr, tpr, _ = roc_curve(y_test, y_pred)
        plt.plot([0, 1], [0, 1], linestyle='--', label='No Skill')
        plt.plot(fpr, tpr, marker='.', label='XGBM')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc='lower right')

        # Computed once; the inputs do not change before it is logged.
        score = roc_auc_score(y_test, y_pred)

        # Persist the current pyplot figure, then clear it so later plots
        # start from an empty canvas.
        output_date = datetime.date.today()
        filepath_ = f'data/07_model_output/ROC_plot_XGB{output_date}.png'
        single_plot_writer = MatplotlibWriter(filepath=filepath_)
        single_plot_writer.save(plt)
        plt.clf()

        logger = logging.getLogger(__name__)
        logger.info('XGBoost AUC is %.3f.', score)

    return pd.DataFrame({'ID': output_id, 'y_pred': y_pred})
def versioned_plot_writer(tmp_path, load_version, save_version):
    """Return a versioned ``MatplotlibWriter`` for ``matplotlib.png`` in ``tmp_path``."""
    target = tmp_path / "matplotlib.png"
    version = Version(load_version, save_version)
    return MatplotlibWriter(filepath=target.as_posix(), version=version)
def plot_writer(mocked_s3_bucket):  # pylint: disable=unused-argument
    """Return a ``MatplotlibWriter`` for ``FULL_PATH`` using ``CREDENTIALS``.

    ``mocked_s3_bucket`` is accepted but not used in the body.
    """
    writer = MatplotlibWriter(filepath=FULL_PATH, credentials=CREDENTIALS)
    return writer