class PartitionedDataSet(AbstractDataSet):
    # pylint: disable=too-many-instance-attributes,protected-access
    """``PartitionedDataSet`` loads and saves partitioned file-like data using
    the underlying dataset definition. For filesystem level operations it uses
    `fsspec`: https://github.com/intake/filesystem_spec.

    Example:
    ::

        >>> import pandas as pd
        >>> from kedro.io import PartitionedDataSet
        >>>
        >>> # these credentials will be passed to both 'fsspec.filesystem()' call
        >>> # and the dataset initializer
        >>> credentials = {"key1": "secret1", "key2": "secret2"}
        >>>
        >>> data_set = PartitionedDataSet(
        >>>     path="s3://bucket-name/path/to/folder",
        >>>     dataset="CSVDataSet",
        >>>     credentials=credentials
        >>> )
        >>> loaded = data_set.load()
        >>> # assert isinstance(loaded, dict)
        >>>
        >>> combine_all = pd.DataFrame()
        >>>
        >>> for partition_id, partition_load_func in loaded.items():
        >>>     partition_data = partition_load_func()
        >>>     combine_all = pd.concat(
        >>>         [combine_all, partition_data], ignore_index=True, sort=True
        >>>     )
        >>>
        >>> new_data = pd.DataFrame({"new": [1, 2]})
        >>> # creates "s3://bucket-name/path/to/folder/new/partition.csv"
        >>> data_set.save({"new/partition.csv": new_data})
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        path: str,
        dataset: Union[str, Type[AbstractDataSet], Dict[str, Any]],
        filepath_arg: str = "filepath",
        filename_suffix: str = "",
        credentials: Dict[str, Any] = None,
        load_args: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ):
        """Creates a new instance of ``PartitionedDataSet``.

        Args:
            path: Path to the folder containing partitioned data. If path
                starts with the protocol (e.g., ``s3://``) then the
                corresponding ``fsspec`` concrete filesystem implementation
                will be used. If protocol is not specified,
                ``fsspec.implementations.local.LocalFileSystem`` will be used.
                **Note:** Some concrete implementations are bundled with
                ``fsspec``, while others (like ``s3`` or ``gcs``) must be
                installed separately prior to usage of the
                ``PartitionedDataSet``.
            dataset: Underlying dataset definition. This is used to instantiate
                the dataset for each file located inside the ``path``.
                Accepted formats are:
                a) object of a class that inherits from ``AbstractDataSet``
                b) a string representing a fully qualified class name to such class
                c) a dictionary with ``type`` key pointing to a string from b),
                other keys are passed to the dataset initializer.
                Credentials for the dataset can be explicitly specified in
                this configuration.
            filepath_arg: Underlying dataset initializer argument that will
                contain a path to each corresponding partition file.
                If unspecified, defaults to "filepath".
            filename_suffix: If specified, only partitions that end with this
                string will be processed.
            credentials: Protocol-specific options that will be passed to
                ``fsspec.filesystem``
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.filesystem
                and the dataset initializer. If the dataset config contains an
                explicit credentials spec, then such spec will take precedence.
                All possible credentials management scenarios are documented here:
                https://kedro.readthedocs.io/en/stable/04_user_guide/08_advanced_io.html#partitioned-dataset-credentials
            load_args: Keyword arguments to be passed into ``find()`` method of
                the filesystem implementation.
            fs_args: Extra arguments to pass into underlying filesystem class
                constructor (e.g. `{"project": "my-project"}` for
                ``GCSFileSystem``).

        Raises:
            DataSetError: If versioning is enabled for the underlying dataset.
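        Example of the dictionary form (format c) of the ``dataset`` argument;
        the dataset type and keys shown here are purely illustrative:
        ::

            >>> dataset = {
            >>>     "type": "pandas.CSVDataSet",
            >>>     "load_args": {"sep": ";"},
            >>> }
            >>> data_set = PartitionedDataSet(
            >>>     path="s3://bucket-name/path/to/folder", dataset=dataset
            >>> )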
""" # pylint: disable=import-outside-toplevel from fsspec.utils import infer_storage_options # for performance reasons super().__init__() self._path = path self._filename_suffix = filename_suffix self._protocol = infer_storage_options(self._path)["protocol"] self._partition_cache = Cache(maxsize=1) dataset = dataset if isinstance(dataset, dict) else {"type": dataset} self._dataset_type, self._dataset_config = parse_dataset_definition( dataset) if VERSION_KEY in self._dataset_config: raise DataSetError( "`{}` does not support versioning of the underlying dataset. " "Please remove `{}` flag from the dataset definition.".format( self.__class__.__name__, VERSIONED_FLAG_KEY)) if credentials: if CREDENTIALS_KEY in self._dataset_config: self._logger.warning( KEY_PROPAGATION_WARNING, { "keys": CREDENTIALS_KEY, "target": "underlying dataset" }, ) else: self._dataset_config[CREDENTIALS_KEY] = deepcopy(credentials) self._credentials = deepcopy(credentials) or {} self._fs_args = deepcopy(fs_args) or {} if self._fs_args: if "fs_args" in self._dataset_config: self._logger.warning( KEY_PROPAGATION_WARNING, { "keys": "filesystem arguments", "target": "underlying dataset" }, ) else: self._dataset_config["fs_args"] = deepcopy(self._fs_args) self._filepath_arg = filepath_arg if self._filepath_arg in self._dataset_config: warn( "`{}` key must not be specified in the dataset definition as it " "will be overwritten by partition path".format( self._filepath_arg)) self._load_args = deepcopy(load_args) or {} self._sep = self._filesystem.sep # since some filesystem implementations may implement a global cache self._invalidate_caches() @property def _filesystem(self): # for performance reasons import fsspec # pylint: disable=import-outside-toplevel protocol = "s3" if self._protocol in S3_PROTOCOLS else self._protocol return fsspec.filesystem(protocol, **self._credentials, **self._fs_args) @property def _normalized_path(self) -> str: if self._protocol in S3_PROTOCOLS: return urlparse(self._path)._replace(scheme="s3").geturl() return self._path @cachedmethod(cache=operator.attrgetter("_partition_cache")) def _list_partitions(self) -> List[str]: return [ path for path in self._filesystem.find(self._normalized_path, ** self._load_args) if path.endswith(self._filename_suffix) ] def _join_protocol(self, path: str) -> str: if self._path.startswith( self._protocol) and not path.startswith(self._protocol): return f"{self._protocol}://{path}" return path def _partition_to_path(self, path: str): dir_path = self._path.rstrip(self._sep) path = path.lstrip(self._sep) full_path = self._sep.join([dir_path, path]) + self._filename_suffix return full_path def _path_to_partition(self, path: str) -> str: dir_path = self._filesystem._strip_protocol(self._normalized_path) path = path.split(dir_path, 1).pop().lstrip(self._sep) if self._filename_suffix and path.endswith(self._filename_suffix): path = path[:-len(self._filename_suffix)] return path def _load(self) -> Dict[str, Callable[[], Any]]: partitions = {} for partition in self._list_partitions(): kwargs = deepcopy(self._dataset_config) # join the protocol back since PySpark may rely on it kwargs[self._filepath_arg] = self._join_protocol(partition) dataset = self._dataset_type(**kwargs) # type: ignore partition_id = self._path_to_partition(partition) partitions[partition_id] = dataset.load if not partitions: raise DataSetError(f"No partitions found in `{self._path}`") return partitions def _save(self, data: Dict[str, Any]) -> None: for partition_id, partition_data in 
    def _save(self, data: Dict[str, Any]) -> None:
        for partition_id, partition_data in sorted(data.items()):
            kwargs = deepcopy(self._dataset_config)
            partition = self._partition_to_path(partition_id)
            # join the protocol back since tools like PySpark may rely on it
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)  # type: ignore
            dataset.save(partition_data)
        self._invalidate_caches()

    def _describe(self) -> Dict[str, Any]:
        clean_dataset_config = (
            {k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY}
            if isinstance(self._dataset_config, dict)
            else self._dataset_config
        )
        return dict(
            path=self._path,
            dataset_type=self._dataset_type.__name__,
            dataset_config=clean_dataset_config,
        )

    def _invalidate_caches(self):
        self._partition_cache.clear()
        self._filesystem.invalidate_cache(self._normalized_path)

    def _exists(self) -> bool:
        return bool(self._list_partitions())

    def _release(self) -> None:
        super()._release()
        self._invalidate_caches()
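# A minimal sketch (illustrative, not part of Kedro's public API) of a node
# function consuming the lazy partitions returned by
# ``PartitionedDataSet.load()``. It assumes every partition's load callable
# returns a pandas DataFrame, mirroring the class docstring example above.
def _concat_partitions_example(partitions: Dict[str, Callable[[], Any]]):
    import pandas as pd  # assumed to be available to the caller

    combined = pd.DataFrame()
    for _partition_id, load_partition in sorted(partitions.items()):
        # data for a partition is only read at the point of this call
        combined = pd.concat([combined, load_partition()], ignore_index=True)
    return combined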
class AbstractVersionedDataSet(AbstractDataSet, abc.ABC):
    """
    ``AbstractVersionedDataSet`` is the base class for all versioned data set
    implementations. All data sets that implement versioning should extend this
    abstract class and implement the methods marked as abstract.

    Example:
    ::

        >>> from pathlib import Path, PurePosixPath
        >>> import pandas as pd
        >>> from kedro.io import AbstractVersionedDataSet
        >>>
        >>>
        >>> class MyOwnDataSet(AbstractVersionedDataSet):
        >>>     def __init__(self, filepath, version, param1, param2=True):
        >>>         super().__init__(PurePosixPath(filepath), version)
        >>>         self._param1 = param1
        >>>         self._param2 = param2
        >>>
        >>>     def _load(self) -> pd.DataFrame:
        >>>         load_path = self._get_load_path()
        >>>         return pd.read_csv(load_path)
        >>>
        >>>     def _save(self, df: pd.DataFrame) -> None:
        >>>         save_path = self._get_save_path()
        >>>         df.to_csv(str(save_path))
        >>>
        >>>     def _exists(self) -> bool:
        >>>         path = self._get_load_path()
        >>>         return Path(path.as_posix()).exists()
        >>>
        >>>     def _describe(self):
        >>>         return dict(version=self._version, param1=self._param1, param2=self._param2)

    Example catalog.yml specification:
    ::

        my_dataset:
          type: <path-to-my-own-dataset>.MyOwnDataSet
          filepath: data/01_raw/my_data.csv
          versioned: true
          param1: <param1-value>  # param1 is a required argument
          # param2 will be True by default
    """

    def __init__(
        self,
        filepath: PurePosixPath,
        version: Optional[Version],
        exists_function: Callable[[str], bool] = None,
        glob_function: Callable[[str], List[str]] = None,
    ):
        """Creates a new instance of ``AbstractVersionedDataSet``.

        Args:
            filepath: Filepath in POSIX format to a file.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            exists_function: Function that is used for determining whether
                a path exists in a filesystem.
            glob_function: Function that is used for finding all paths
                in a filesystem, which match a given pattern.
        """
        self._filepath = filepath
        self._version = version
        self._exists_function = exists_function or _local_exists
        self._glob_function = glob_function or iglob
        # 1 entry for load version, 1 for save version
        self._version_cache = Cache(maxsize=2)
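    # Illustrative directory layout, assuming the default timestamp versions
    # produced by ``generate_timestamp()``: for a ``filepath`` of
    # ``data/01_raw/cars.csv`` the versioned copies live at
    #
    #     data/01_raw/cars.csv/2019-12-02T10.15.30.123Z/cars.csv
    #     data/01_raw/cars.csv/2019-12-03T09.00.01.456Z/cars.csv
    #
    # ``_fetch_latest_load_version`` below globs ``_get_versioned_path("*")``
    # and returns the name of the most recent version directory that exists.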
    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "load"))
    def _fetch_latest_load_version(self) -> str:
        # When load version is unpinned, fetch the most recent existing
        # version from the given path.
        pattern = str(self._get_versioned_path("*"))
        version_paths = sorted(self._glob_function(pattern), reverse=True)
        most_recent = next(
            (path for path in version_paths if self._exists_function(path)), None
        )

        if not most_recent:
            raise VersionNotFoundError(f"Did not find any versions for {self}")

        return PurePath(most_recent).parent.name

    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "save"))
    def _fetch_latest_save_version(self) -> str:  # pylint: disable=no-self-use
        """Generate and cache the current save version"""
        return generate_timestamp()

    def resolve_load_version(self) -> Optional[str]:
        """Compute the version the dataset should be loaded with."""
        if not self._version:
            return None
        if self._version.load:
            return self._version.load
        return self._fetch_latest_load_version()

    def _get_load_path(self) -> PurePosixPath:
        if not self._version:
            # When versioning is disabled, load from original filepath
            return self._filepath
        load_version = self.resolve_load_version()
        return self._get_versioned_path(load_version)  # type: ignore

    def resolve_save_version(self) -> Optional[str]:
        """Compute the version the dataset should be saved with."""
        if not self._version:
            return None
        if self._version.save:
            return self._version.save
        return self._fetch_latest_save_version()

    def _get_save_path(self) -> PurePosixPath:
        if not self._version:
            # When versioning is disabled, return original filepath
            return self._filepath

        save_version = self.resolve_save_version()
        versioned_path = self._get_versioned_path(save_version)  # type: ignore

        if self._exists_function(str(versioned_path)):
            raise DataSetError(
                f"Save path `{versioned_path}` for {str(self)} must not exist if "
                f"versioning is enabled."
            )

        return versioned_path

    def _get_versioned_path(self, version: str) -> PurePosixPath:
        return self._filepath / version / self._filepath.name

    def load(self) -> Any:
        self.resolve_load_version()  # Make sure last load version is set
        return super().load()

    def save(self, data: Any) -> None:
        self._version_cache.clear()
        save_version = self.resolve_save_version()  # Make sure last save version is set

        try:
            super().save(data)
        except (FileNotFoundError, NotADirectoryError) as err:
            # FileNotFoundError raised in Win, NotADirectoryError raised in Unix
            _default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
            raise DataSetError(
                f"Cannot save versioned dataset `{self._filepath.name}` to "
                f"`{self._filepath.parent.as_posix()}` because a file with the same "
                f"name already exists in the directory. This is likely because "
                f"versioning was enabled on a dataset already saved previously. Either "
                f"remove `{self._filepath.name}` from the directory or manually "
                f"convert it into a versioned dataset by placing it in a versioned "
                f"directory (e.g. with default versioning format "
                f"`{self._filepath.as_posix()}/{_default_version}/{self._filepath.name}"
                f"`)."
            ) from err

        load_version = self.resolve_load_version()
        if load_version != save_version:
            warnings.warn(
                _CONSISTENCY_WARNING.format(save_version, load_version, str(self))
            )
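    # Note: ``save()`` clears ``_version_cache`` up front so that a fresh save
    # version is generated, and re-resolves the load version afterwards; if the
    # latest version found on storage differs from the one just saved (e.g.
    # due to a concurrent writer), ``_CONSISTENCY_WARNING`` is emitted instead
    # of an error being raised.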
""" self._logger.debug("Checking whether target of %s exists", str(self)) try: return self._exists() except VersionNotFoundError: return False except Exception as exc: # SKIP_IF_NO_SPARK message = ( f"Failed during exists check for data set {str(self)}.\n{str(exc)}" ) raise DataSetError(message) from exc def _release(self) -> None: super()._release() self._version_cache.clear()