예제 #1
1
    def test_save_invalidates_cache(self, dataset, local_csvs):
        """A saved partition shows up in the next `load`, while the mapping
        returned by an earlier `load` is left untouched."""
        partitioned = PartitionedDataSet(str(local_csvs), dataset)
        before_save = partitioned.load()

        frame = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
        new_id = "new/data.csv"
        partitioned.save({new_id: frame})

        assert new_id not in before_save
        after_save = partitioned.load()
        assert new_id in after_save
예제 #2
0
    def test_exists(self, local_csvs, dataset):
        """`exists()` holds for the populated fixture directory, but not for a
        directory that is missing or that exists without any files."""
        def _has_partitions(location):
            return PartitionedDataSet(str(location), dataset).exists()

        assert _has_partitions(local_csvs)

        empty_dir = local_csvs / "empty" / "folder"
        assert not _has_partitions(empty_dir)  # directory not created yet
        empty_dir.mkdir(parents=True)
        assert not _has_partitions(empty_dir)  # created, but holds no files
예제 #3
0
    def test_load(self, dataset, mocked_csvs_in_s3, partitioned_data_pandas):
        """Every partition under the mocked S3 path is listed, and each lazy
        loader returns the DataFrame originally written for it."""
        partitions = PartitionedDataSet(mocked_csvs_in_s3, dataset).load()

        assert partitions.keys() == partitioned_data_pandas.keys()
        for part_name, loader in partitions.items():
            actual_frame = loader()
            assert_frame_equal(actual_frame, partitioned_data_pandas[part_name])
예제 #4
0
    def test_exists(self, dataset, mocked_csvs_in_s3):
        """On S3, `exists()` is true for the populated path and false for an
        empty prefix, both before and after `mkdir` is called on it."""
        def _has_partitions(location):
            return PartitionedDataSet(location, dataset).exists()

        assert _has_partitions(mocked_csvs_in_s3)

        empty_prefix = "/".join([mocked_csvs_in_s3, "empty", "folder"])
        assert not _has_partitions(empty_prefix)

        s3fs.S3FileSystem().mkdir(empty_prefix)
        assert not _has_partitions(empty_prefix)
    def test_overwrite(self, local_csvs, overwrite, expected_num_parts):
        """After saving one new partition, the number of loaded partitions
        matches `expected_num_parts` for the given `overwrite` setting."""
        pds = PartitionedDataSet(
            str(local_csvs), "pandas.CSVDataSet", overwrite=overwrite
        )
        frame = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
        new_id = "new/data"
        pds.save({new_id: frame})

        partitions = pds.load()
        assert new_id in partitions
        assert len(partitions) == expected_num_parts
예제 #6
0
    def test_invalid_dataset(self, dataset, local_csvs):
        """Each CSV partition fails to load through the parquet dataset type,
        and the error names both the cause and the failing partition."""
        expected_pattern = r"Failed while loading data from data set ParquetDataSet(.*)"
        partitions = PartitionedDataSet(str(local_csvs), dataset).load()

        for part_name, loader in partitions.items():
            with pytest.raises(DataSetError, match=expected_pattern) as raised:
                loader()
            message = str(raised.value)
            assert (
                "Either the file is corrupted or this is not a parquet file"
                in message)
            assert str(part_name) in message
예제 #7
0
    def test_save(self, dataset, mocked_csvs_in_s3):
        """A saved partition is written to `<path>/<partition id>` on S3 and
        round-trips unchanged through `load`."""
        pds = PartitionedDataSet(mocked_csvs_in_s3, dataset)
        frame = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
        new_id = "new/data.csv"
        pds.save({new_id: frame})

        target_key = "/".join([mocked_csvs_in_s3, new_id])
        assert s3fs.S3FileSystem().exists(target_key)

        partitions = pds.load()
        assert new_id in partitions
        assert_frame_equal(partitions[new_id](), frame)
예제 #8
0
    def test_save(self, dataset, local_csvs, suffix):
        """With `filename_suffix`, the file on disk carries the suffix, while
        the partition id used to save and load does not."""
        pds = PartitionedDataSet(str(local_csvs), dataset, filename_suffix=suffix)
        frame = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
        new_id = "new/data"
        pds.save({new_id: frame})

        written_file = local_csvs / "new" / ("data" + suffix)
        assert written_file.is_file()

        partitions = pds.load()
        assert new_id in partitions
        assert_frame_equal(partitions[new_id](), frame)
예제 #9
0
    def test_load(self, dataset, local_csvs, partitioned_data_pandas, suffix,
                  expected_num_parts):
        """Loading with `filename_suffix` yields `expected_num_parts`
        partitions whose ids have the suffix stripped and still map back to
        the on-disk file names."""
        pds = PartitionedDataSet(str(local_csvs), dataset, filename_suffix=suffix)
        partitions = pds.load()

        assert len(partitions) == expected_num_parts
        for part_name, loader in partitions.items():
            actual_frame = loader()
            on_disk_name = part_name + suffix
            assert_frame_equal(actual_frame, partitioned_data_pandas[on_disk_name])
            # the returned partition id must not keep the suffix
            if suffix:
                assert not part_name.endswith(suffix)
예제 #10
0
    def test_load_args(self, mocker):
        """`load_args` are forwarded unchanged to the filesystem's `find`
        call that discovers partitions."""
        partition_name = "fake_partition"
        fs_factory = mocker.patch("fsspec.filesystem")
        find_mock = fs_factory.return_value.find
        find_mock.return_value = [partition_name]

        cwd = str(Path.cwd())
        find_kwargs = {"maxdepth": 42, "withdirs": True}
        pds = PartitionedDataSet(cwd, "CSVLocalDataSet", load_args=find_kwargs)
        # map every discovered path straight to the fake partition id
        mocker.patch.object(pds, "_path_to_partition", return_value=partition_name)

        assert pds.load().keys() == {partition_name}
        find_mock.assert_called_once_with(cwd, **find_kwargs)
예제 #11
0
    def test_release_instance_cache(self, local_csvs):
        """Releasing one dataset clears only its own partition cache; a second
        instance over the same path keeps its cached listing."""
        first = PartitionedDataSet(str(local_csvs), "pandas.CSVDataSet")
        first.load()
        second = PartitionedDataSet(str(local_csvs), "pandas.CSVDataSet")
        second.load()

        for instance in (first, second):
            assert instance._partition_cache.currsize == 1

        first.release()
        assert first._partition_cache.currsize == 0
        # the other instance is unaffected by the release above
        assert second._partition_cache.currsize == 1
예제 #12
0
 def test_versioned_dataset_not_allowed(self, dataset_config):
     """A dataset definition that asks for versioning is rejected when the
     `PartitionedDataSet` is constructed."""
     cwd = str(Path.cwd())
     escaped_message = re.escape(
         "`PartitionedDataSet` does not support versioning of the underlying "
         "dataset. Please remove `versioned` flag from the dataset definition."
     )
     with pytest.raises(DataSetError, match=escaped_message):
         PartitionedDataSet(cwd, dataset_config)
예제 #13
0
    def test_credentials(self, mocker, credentials, expected_pds_creds,
                         expected_dataset_creds):
        """Checks where credentials end up and that they never leak:
        the filesystem is created with `expected_pds_creds`, the dataset
        config holds `expected_dataset_creds` (or no credentials key), and
        nothing from `credentials` appears in `str(pds)`."""
        fs_factory = mocker.patch("fsspec.filesystem")
        cwd = str(Path.cwd())
        pds = PartitionedDataSet(cwd, "pandas.CSVDataSet", credentials=credentials)

        assert fs_factory.call_count == 2
        fs_factory.assert_called_with("file", **expected_pds_creds)
        if not expected_dataset_creds:
            assert CREDENTIALS_KEY not in pds._dataset_config
        else:
            assert pds._dataset_config[CREDENTIALS_KEY] == expected_dataset_creds

        rendered = str(pds)

        def _check_hidden(item):
            # recurse into dicts so every nested key and value is checked;
            # the dict's own str() is checked as well (no `elif` on purpose)
            if isinstance(item, dict):
                for key, val in item.items():
                    _check_hidden(key)
                    _check_hidden(val)
            if item is not None:
                assert str(item) not in rendered

        _check_hidden(credentials)
예제 #14
0
    def test_describe(self, dataset):
        """The string form of an S3-backed dataset mentions its path, the
        dataset type and the dataset config."""
        s3_path = f"s3://{BUCKET_NAME}/foo/bar"
        description = str(PartitionedDataSet(s3_path, dataset))

        for fragment in (f"path={s3_path}", "dataset_type=CSVDataSet",
                         "dataset_config"):
            assert fragment in description
예제 #15
0
    def test_describe(self, dataset):
        """The string form of a local dataset mentions its path, the dataset
        type and the dataset config."""
        cwd = str(Path.cwd())
        description = str(PartitionedDataSet(cwd, dataset))

        expected_fragments = [
            f"path={cwd}",
            "dataset_type=CSVDataSet",
            "dataset_config",
        ]
        for fragment in expected_fragments:
            assert fragment in description
    def test_describe(self, dataset):
        """The string form shows the path and dataset type. `dataset_config`
        appears only when the dataset dict has keys other than `type`."""
        cwd = str(Path.cwd())
        description = str(PartitionedDataSet(cwd, dataset))

        assert "path={}".format(cwd) in description
        assert "dataset_type=CSVDataSet" in description
        extra_keys = dataset.keys() - {"type"} if isinstance(dataset, dict) else set()
        if not extra_keys:
            assert "dataset_config" not in description
        else:
            assert "dataset_config" in description
    def test_describe(self, dataset):
        """For an S3 path, the string form shows the path and dataset type.
        `dataset_config` appears only when the dataset dict has keys other
        than `type`."""
        s3_path = "s3://{}/foo/bar".format(BUCKET_NAME)
        description = str(PartitionedDataSet(s3_path, dataset))

        assert "path={}".format(s3_path) in description
        assert "dataset_type=CSVDataSet" in description
        extra_keys = dataset.keys() - {"type"} if isinstance(dataset, dict) else set()
        if extra_keys:
            assert "dataset_config" in description
        else:
            assert "dataset_config" not in description
    def test_load_s3a(self, mocked_csvs_in_s3, partitioned_data_pandas,
                      mocker):
        """Partitions under an `s3a://` path are discovered, and each
        underlying dataset is built with the full `s3a://` filepath."""
        bucket_and_prefix = mocked_csvs_in_s3.split("://", 1)[1]
        s3a_path = f"s3a://{bucket_and_prefix}"
        # the dataset type only has to pass the constructor's check;
        # `_dataset_type` is replaced by a mock immediately afterwards
        pds = PartitionedDataSet(s3a_path, "pandas.CSVDataSet")
        assert pds._protocol == "s3a"

        dataset_mock = mocker.patch.object(pds, "_dataset_type")
        dataset_mock.__name__ = "mocked"
        partitions = pds.load()

        assert partitions.keys() == partitioned_data_pandas.keys()
        assert dataset_mock.call_count == len(partitions)
        dataset_mock.assert_has_calls(
            [mocker.call(filepath=f"{s3a_path}/{part}") for part in partitions],
            any_order=True,
        )
예제 #19
0
    def test_fs_args(self, mocker):
        """`fs_args` are passed to the filesystem factory and are also copied
        into the underlying dataset config."""
        extra_fs_args = {"foo": "bar"}
        fs_factory = mocker.patch("fsspec.filesystem")
        cwd = str(Path.cwd())
        pds = PartitionedDataSet(cwd, "pandas.CSVDataSet", fs_args=extra_fs_args)

        assert fs_factory.call_count == 2
        fs_factory.assert_called_with("file", **extra_fs_args)
        assert pds._dataset_config["fs_args"] == extra_fs_args
    def test_save_s3a(self, mocked_csvs_in_s3, mocker):
        """Saving under an `s3a://` path builds the partition filepath with
        the `s3a` protocol and the configured filename suffix."""
        s3a_path = "s3a://" + mocked_csvs_in_s3.split("://", 1)[1]
        pds = PartitionedDataSet(
            s3a_path, "pandas.CSVDataSet", filename_suffix=".csv"
        )
        assert pds._protocol == "s3a"

        # replace the real dataset class with a mock; the type named above
        # only has to satisfy the constructor
        dataset_mock = mocker.patch.object(pds, "_dataset_type")
        dataset_mock.__name__ = "mocked"
        partition, payload = "new/data", "data"

        pds.save({partition: payload})
        dataset_mock.assert_called_once_with(filepath=f"{s3a_path}/{partition}.csv")
        dataset_mock.return_value.save.assert_called_once_with(payload)
예제 #21
0
 def test_dataset_creds_deprecated(self, pds_config,
                                   expected_dataset_creds):
     """Passing dataset credentials through the legacy `dataset_credentials`
     key emits a DeprecationWarning, and the resulting dataset config holds
     `expected_dataset_creds`."""
     expected_warning = re.escape(
         "Support for `dataset_credentials` key in the credentials is now "
         "deprecated and will be removed in the next version. Please specify "
         "the dataset credentials explicitly inside the dataset config.")
     with pytest.warns(DeprecationWarning, match=expected_warning):
         pds = PartitionedDataSet(path=str(Path.cwd()), **pds_config)
     assert pds._dataset_config["credentials"] == expected_dataset_creds
예제 #22
0
    def test_release(self, dataset, local_csvs):
        """The partition listing is cached; deleting a file goes unnoticed
        until `release()` forces the listing to be rebuilt."""
        removed = "p2.csv"
        pds = PartitionedDataSet(str(local_csvs), dataset)
        before = pds.load()
        assert removed in before

        # delete the file behind the dataset's back; the cached listing still has it
        (local_csvs / removed).unlink()
        assert pds.load().keys() == before.keys()

        pds.release()
        after = pds.load()
        # exactly the removed partition differs between the two listings
        assert before.keys() ^ after.keys() == {removed}
예제 #23
0
    def test_release(self, dataset, mocked_csvs_in_s3):
        """On S3, the partition listing is cached until `release()` is
        called; after that a deleted object disappears from `load`."""
        removed = "p2.csv"
        pds = PartitionedDataSet(mocked_csvs_in_s3, dataset)
        before = pds.load()
        assert removed in before

        # remove the object directly; the cached listing still reports it
        s3fs.S3FileSystem().rm("/".join([mocked_csvs_in_s3, removed]))
        cached = pds.load()
        assert before.keys() == cached.keys()

        pds.release()
        refreshed = pds.load()
        assert before.keys() ^ refreshed.keys() == {removed}
 def test_credentials_log_warning(self, caplog):
     """When both levels define credentials, the dataset-level ones are kept
     and a single warning is logged about the top-level ones."""
     dataset_definition = {
         "type": CSVDataSet,
         "credentials": {"secret": "dataset"},
     }
     pds = PartitionedDataSet(
         path=str(Path.cwd()),
         dataset=dataset_definition,
         credentials={"secret": "global"},
     )
     expected_record = (
         "kedro.io.core",
         logging.WARNING,
         "Top-level credentials will not propagate into the underlying dataset "
         "since credentials were explicitly defined in the dataset config.",
     )
     assert caplog.record_tuples == [expected_record]
     assert pds._dataset_config["credentials"] == {"secret": "dataset"}
예제 #25
0
 def test_fs_args_log_warning(self, caplog):
     """Check that the warning is logged if the dataset filesystem
     arguments will overwrite the top-level ones, and that the dataset-level
     arguments are the ones kept."""
     pds = PartitionedDataSet(
         path=str(Path.cwd()),
         dataset={
             "type": CSVDataSet,
             "fs_args": {
                 "args": "dataset"
             }
         },
         # use a value distinct from the dataset-level one; otherwise the final
         # assertion could not tell which of the two arguments was kept
         fs_args={"args": "global"},
     )
     log_message = KEY_PROPAGATION_WARNING % {
         "keys": "filesystem arguments",
         "target": "underlying dataset",
     }
     assert caplog.record_tuples == [("kedro.io.core", logging.WARNING,
                                      log_message)]
     assert pds._dataset_config["fs_args"] == {"args": "dataset"}
예제 #26
0
    def test_save_invalidates_cache(self, local_csvs, mocker):
        """`save` empties the partition cache and asks the filesystem to
        invalidate its cache, so the next `load` sees the new partition."""
        pds = PartitionedDataSet(str(local_csvs), "pandas.CSVDataSet")
        fs_invalidate = mocker.patch.object(pds._filesystem, "invalidate_cache")

        before_save = pds.load()
        assert pds._partition_cache.currsize == 1
        fs_invalidate.assert_not_called()  # a plain load leaves the fs cache alone

        frame = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
        new_id = "new/data.csv"
        pds.save({new_id: frame})
        assert pds._partition_cache.currsize == 0
        # it seems `_filesystem.invalidate_cache` calls itself internally, so
        # the mock may be hit twice per save; only require that the normalized
        # path was passed at least once
        fs_invalidate.assert_any_call(pds._normalized_path)

        after_save = pds.load()
        assert new_id not in before_save
        assert new_id in after_save
예제 #27
0
 def test_filepath_arg_warning(self, pds_config, filepath_arg):
     """Defining the key named by `filepath_arg` in the dataset definition
     emits a UserWarning, because each partition supplies its own path for it."""
     # only the first fragment interpolates; the second is a plain literal
     # (a bare `f` prefix without placeholders is flagged by linters, F541)
     pattern = (
         f"`{filepath_arg}` key must not be specified in the dataset definition as it "
         "will be overwritten by partition path")
     with pytest.warns(UserWarning, match=re.escape(pattern)):
         PartitionedDataSet(**pds_config)
예제 #28
0
 def test_invalid_dataset_config(self, dataset_config, error_pattern):
     """A dataset config the constructor rejects raises DataSetError that
     matches the parametrized `error_pattern`."""
     cwd = str(Path.cwd())
     with pytest.raises(DataSetError, match=error_pattern):
         PartitionedDataSet(cwd, dataset_config)
예제 #29
0
    def test_no_partitions(self, tmpdir):
        """Loading from the fresh pytest `tmpdir` raises a DataSetError that
        names the path."""
        expected = re.escape(f"No partitions found in `{tmpdir}`")
        pds = PartitionedDataSet(str(tmpdir), "pandas.CSVDataSet")

        with pytest.raises(DataSetError, match=expected):
            pds.load()
예제 #30
0
 def test_dataset_creds(self, pds_config, expected_ds_creds, global_creds):
     """Top-level and dataset-level credentials are stored independently
     and do not overwrite each other."""
     pds = PartitionedDataSet(path=str(Path.cwd()), **pds_config)
     dataset_creds = pds._dataset_config["credentials"]
     assert dataset_creds == expected_ds_creds
     assert pds._credentials == global_creds