def test_multiple_loads(self, versioned_csv_data_set, dummy_dataframe, filepath_csv):
    """Test that if a new version is created mid-run, by an external system,
    it won't be loaded in the current run."""
    versioned_csv_data_set.save(dummy_dataframe)
    versioned_csv_data_set.load()
    v1 = versioned_csv_data_set.resolve_load_version()

    sleep(0.5)
    # force-drop a newer version into the same location
    v_new = generate_timestamp()
    GenericDataSet(
        filepath=filepath_csv.as_posix(),
        file_format="csv",
        version=Version(v_new, v_new),
    ).save(dummy_dataframe)

    versioned_csv_data_set.load()
    v2 = versioned_csv_data_set.resolve_load_version()

    assert v2 == v1  # v2 should not be v_new!
    ds_new = GenericDataSet(
        filepath=filepath_csv.as_posix(),
        file_format="csv",
        version=Version(None, None),
    )
    assert (
        ds_new.resolve_load_version() == v_new
    )  # new version is discoverable by a new instance
def test_release_instance_cache(self, dummy_dataframe, filepath_csv):
    """Test that cache invalidation does not affect other instances"""
    ds_a = GenericDataSet(
        filepath=filepath_csv.as_posix(),
        file_format="csv",
        version=Version(None, None),
    )
    assert ds_a._version_cache.currsize == 0
    ds_a.save(dummy_dataframe)  # create a version
    assert ds_a._version_cache.currsize == 2

    ds_b = GenericDataSet(
        filepath=filepath_csv.as_posix(),
        file_format="csv",
        version=Version(None, None),
    )
    assert ds_b._version_cache.currsize == 0
    ds_b.resolve_save_version()
    assert ds_b._version_cache.currsize == 1
    ds_b.resolve_load_version()
    assert ds_b._version_cache.currsize == 2

    # dataset A cache is cleared
    ds_a.release()
    assert ds_a._version_cache.currsize == 0

    # dataset B cache is unaffected
    assert ds_b._version_cache.currsize == 2
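# --- Illustrative sketch (not part of the original tests) -------------------
# Shows how the cache behaviour verified above plays out in user code: a
# versioned dataset pins the load version it resolves first, and `release()`
# is what drops that cached version so a later load can see data written by
# another process. The file path and helper name here are hypothetical.
def _example_release_picks_up_external_version(dummy_dataframe):
    ds = GenericDataSet(
        filepath="data/output.csv",   # hypothetical path
        file_format="csv",
        version=Version(None, None),  # "latest" for both load and save
    )
    ds.save(dummy_dataframe)  # resolves and caches save/load versions
    ds.load()                 # reuses the cached load version
    ds.release()              # clears this instance's version cache only
    ds.load()                 # re-resolves, picking up any newer version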
def test_multiple_saves(self, dummy_dataframe, filepath_csv):
    """Test multiple cycles of save followed by load for the same dataset"""
    ds_versioned = GenericDataSet(
        filepath=filepath_csv.as_posix(),
        file_format="csv",
        version=Version(None, None),
    )

    # first save
    ds_versioned.save(dummy_dataframe)
    first_save_version = ds_versioned.resolve_save_version()
    first_load_version = ds_versioned.resolve_load_version()
    assert first_load_version == first_save_version

    # second save
    sleep(0.5)
    ds_versioned.save(dummy_dataframe)
    second_save_version = ds_versioned.resolve_save_version()
    second_load_version = ds_versioned.resolve_load_version()
    assert second_load_version == second_save_version
    assert second_load_version > first_load_version

    # another dataset
    ds_new = GenericDataSet(
        filepath=filepath_csv.as_posix(),
        file_format="csv",
        version=Version(None, None),
    )
    assert ds_new.resolve_load_version() == second_load_version
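# Side note (assumption: version strings use the timestamp format pinned in
# the Spark tests below, e.g. "2019-01-01T23.59.59.999Z"): the comparison
# `second_load_version > first_load_version` above works because that format
# sorts lexicographically in chronological order, for example:
assert "2019-01-02T00.00.00.000Z" > "2019-01-01T23.59.59.999Z"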
def versioned_s3_data_set(load_version, save_version):
    return PickleS3DataSet(
        filepath=FILENAME,
        bucket_name=BUCKET_NAME,
        credentials=AWS_CREDENTIALS,
        version=Version(load_version, save_version),
    )
def test_http_filesystem_no_versioning(self):
    pattern = r"HTTP\(s\) DataSet doesn't support versioning\."

    with pytest.raises(DataSetError, match=pattern):
        MatplotlibWriter(
            filepath="https://example.com/file.png", version=Version(None, None)
        )
def versioned_csv_data_set(filepath_csv, load_version, save_version):
    return GenericDataSet(
        filepath=filepath_csv.as_posix(),
        file_format="csv",
        version=Version(load_version, save_version),
        save_args={"index": False},
    )
def test_save_version_warning(self, mocker):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=exact_version
    )
    mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False)
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save path `{fn}/{f}/{sv}/{f}` did not match load path "
        r"`{fn}/{f}/{lv}/{f}` for SparkDataSet\(.+\)".format(
            fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save, lv=exact_version.load
        )
    )
    with pytest.warns(UserWarning, match=pattern):
        versioned_hdfs.save(mocked_spark_df)

    mocked_spark_df.write.save.assert_called_once_with(
        "hdfs://{fn}/{f}/{sv}/{f}".format(
            fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save
        ),
        "parquet",
    )
def test_load_exact(self, tmp_path, sample_spark_df):
    ts = generate_timestamp()
    ds_local = SparkDataSet(
        filepath=(tmp_path / FILENAME).as_posix(), version=Version(ts, ts)
    )

    ds_local.save(sample_spark_df)
    reloaded = ds_local.load()

    assert reloaded.exceptAll(sample_spark_df).count() == 0
def test_load_exact(self, tmp_path, sample_spark_df):
    ts = generate_timestamp()
    ds_dbfs = SparkDataSet(
        filepath="/dbfs" + str(tmp_path / FILENAME), version=Version(ts, ts)
    )

    ds_dbfs.save(sample_spark_df)
    reloaded = ds_dbfs.load()

    assert reloaded.exceptAll(sample_spark_df).count() == 0
def versioned_blob_csv_data_set(load_version, save_version):
    return CSVBlobDataSet(
        filepath=TEST_FILE_NAME,
        container_name=TEST_CONTAINER_NAME,
        credentials=TEST_CREDENTIALS,
        blob_to_text_args={"to_extra": 41},
        blob_from_text_args={"from_extra": 42},
        version=Version(load_version, save_version),
    )
def versioned_s3_data_set(load_version, save_version):
    return ParquetS3DataSet(
        filepath=FILENAME,
        bucket_name=BUCKET_NAME,
        credentials={
            "aws_access_key_id": "YOUR_KEY",
            "aws_secret_access_key": "YOUR SECRET",
        },
        version=Version(load_version, save_version),
    )
def test_ineffective_overwrite(self, load_version, save_version):
    pattern = (
        "Setting `overwrite=True` is ineffective if versioning "
        "is enabled, since the versioned path must not already "
        "exist; overriding flag with `overwrite=False` instead."
    )
    with pytest.warns(UserWarning, match=pattern):
        versioned_plot_writer = MatplotlibWriter(
            filepath="/tmp/file.txt",
            version=Version(load_version, save_version),
            overwrite=True,
        )
    assert not versioned_plot_writer._overwrite
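# Hedged counterpart (inferred from the warning text above, not taken from the
# original suite): without a `version`, the `overwrite` flag should be kept,
# since a non-versioned save may legitimately replace an existing file.
def _example_overwrite_without_versioning():
    plot_writer = MatplotlibWriter(filepath="/tmp/file.txt", overwrite=True)
    assert plot_writer._overwrite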
def test_save_version_warning(self, tmp_path, sample_spark_df):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    ds_local = SparkDataSet(filepath=str(tmp_path / FILENAME), version=exact_version)

    pattern = (
        r"Save version `{ev.save}` did not match load version "
        r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
    )
    with pytest.warns(UserWarning, match=pattern):
        ds_local.save(sample_spark_df)
def test_load_exact(self, mocker):
    ts = generate_timestamp()
    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=Version(ts, None)
    )
    get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

    versioned_hdfs.load()

    get_spark.return_value.read.load.assert_called_once_with(
        "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, f=FILENAME, v=ts),
        "parquet",
    )
def test_version_str_repr(self, load_version, save_version):
    """Test that version is in string representation of the class instance
    when applicable."""
    filepath = "chart.png"
    chart = MatplotlibWriter(filepath=filepath)
    chart_versioned = MatplotlibWriter(
        filepath=filepath, version=Version(load_version, save_version)
    )
    assert filepath in str(chart)
    assert "version" not in str(chart)

    assert filepath in str(chart_versioned)
    ver_str = f"version=Version(load={load_version}, save='{save_version}')"
    assert ver_str in str(chart_versioned)
def test_load_exact(self, mocker):
    ts = generate_timestamp()
    ds_s3 = SparkDataSet(
        filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
        version=Version(ts, None),
        credentials=AWS_CREDENTIALS,
    )
    get_spark = mocker.patch.object(ds_s3, "_get_spark")

    ds_s3.load()

    get_spark.return_value.read.load.assert_called_once_with(
        "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=ts), "parquet"
    )
def versioned_gcs_data_set(
    load_version,
    save_version,
    load_args,
    save_args,
    mock_gcs_filesystem,  # pylint: disable=unused-argument
):
    return ParquetGCSDataSet(
        bucket_name=BUCKET_NAME,
        filepath=FILENAME,
        credentials=None,
        load_args=load_args,
        save_args=save_args,
        version=Version(load_version, save_version),
        project=GCP_PROJECT,
    )
def test_version_str_repr(self, load_version, save_version):
    """Test that version is in string representation of the class instance
    when applicable."""
    ds = PickleS3DataSet(filepath=FILENAME, bucket_name=BUCKET_NAME)
    ds_versioned = PickleS3DataSet(
        filepath=FILENAME,
        bucket_name=BUCKET_NAME,
        version=Version(load_version, save_version),
    )
    assert FILENAME in str(ds)
    assert "version" not in str(ds)

    assert FILENAME in str(ds_versioned)
    ver_str = "version=Version(load={}, save='{}')".format(load_version, save_version)
    assert ver_str in str(ds_versioned)
def test_version_str_repr(self, load_version, save_version):
    """Test that version is in string representation of the class instance
    when applicable."""
    filepath = "test.json"
    ds = NetworkXDataSet(filepath=filepath)
    ds_versioned = NetworkXDataSet(
        filepath=filepath, version=Version(load_version, save_version)
    )
    assert filepath in str(ds)
    assert "version" not in str(ds)

    assert filepath in str(ds_versioned)
    ver_str = f"version=Version(load={load_version}, save='{save_version}')"
    assert ver_str in str(ds_versioned)
    assert "NetworkXDataSet" in str(ds_versioned)
    assert "NetworkXDataSet" in str(ds)
    assert "protocol" in str(ds_versioned)
    assert "protocol" in str(ds)
def test_version_str_repr(self, filepath_csv, load_version, save_version):
    """Test that version is in string representation of the class instance
    when applicable."""
    filepath = filepath_csv.as_posix()
    ds = GenericDataSet(filepath=filepath, file_format="csv")
    ds_versioned = GenericDataSet(
        filepath=filepath,
        file_format="csv",
        version=Version(load_version, save_version),
    )

    assert filepath in str(ds)
    assert filepath in str(ds_versioned)

    ver_str = f"version=Version(load={load_version}, save='{save_version}')"
    assert ver_str in str(ds_versioned)
    assert "GenericDataSet" in str(ds_versioned)
    assert "GenericDataSet" in str(ds)
    assert "protocol" in str(ds_versioned)
    assert "protocol" in str(ds)
def test_version_str_repr(self, load_version, save_version):
    """Test that version is in string representation of the class instance
    when applicable."""
    ds = CSVBlobDataSet(
        filepath=TEST_FILE_NAME,
        container_name=TEST_CONTAINER_NAME,
        credentials=TEST_CREDENTIALS,
    )
    ds_versioned = CSVBlobDataSet(
        filepath=TEST_FILE_NAME,
        container_name=TEST_CONTAINER_NAME,
        credentials=TEST_CREDENTIALS,
        version=Version(load_version, save_version),
    )
    assert TEST_FILE_NAME in str(ds)
    assert "version" not in str(ds)

    assert TEST_FILE_NAME in str(ds_versioned)
    ver_str = "version=Version(load={}, save='{}')".format(load_version, save_version)
    assert ver_str in str(ds_versioned)
def test_save_version_warning(self, mocker):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    ds_s3 = SparkDataSet(
        filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
        version=exact_version,
        credentials=AWS_CREDENTIALS,
    )
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save version `{ev.save}` did not match load version "
        r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
    )
    with pytest.warns(UserWarning, match=pattern):
        ds_s3.save(mocked_spark_df)

    mocked_spark_df.write.save.assert_called_once_with(
        "s3a://{b}/{f}/{v}/{f}".format(
            b=BUCKET_NAME, f=FILENAME, v=exact_version.save
        ),
        "parquet",
    )
def versioned_networkx_data_set(filepath_json, load_version, save_version):
    return NetworkXDataSet(
        filepath=filepath_json, version=Version(load_version, save_version)
    )
def version():
    load_version = None  # use latest
    save_version = generate_timestamp()  # freeze save version
    return Version(load_version, save_version)
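# Usage note (assumption: `version` above is a pytest fixture consumed by the
# versioned dataset fixtures in this suite, and the helper below is a
# hypothetical illustration): freezing `save_version` to a single
# `generate_timestamp()` call makes every save within one test share the same
# versioned directory, while `load_version=None` keeps loads on "latest".
def _example_consume_version(version, dummy_dataframe, tmp_path):
    ds = GenericDataSet(
        filepath=(tmp_path / "example.csv").as_posix(),  # hypothetical path
        file_format="csv",
        version=version,
    )
    ds.save(dummy_dataframe)
    assert ds.resolve_save_version() == version.save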
def versioned_plot_writer(tmp_path, load_version, save_version):
    filepath = (tmp_path / "matplotlib.png").as_posix()
    return MatplotlibWriter(
        filepath=filepath, version=Version(load_version, save_version)
    )
def versioned_hv_writer(filepath_png, load_version, save_version):
    return HoloviewsWriter(filepath_png, version=Version(load_version, save_version))