# Example No. 1
class PartitionedDataSet(AbstractDataSet):
    # pylint: disable=too-many-instance-attributes,protected-access
    """``PartitionedDataSet`` loads and saves partitioned file-like data using the
    underlying dataset definition. For filesystem level operations it uses
    ``fsspec``.

    Example:
    ::

        >>> import pandas as pd
        >>> from kedro.io import PartitionedDataSet
        >>>
        >>> # these credentials will be passed to both 'fsspec.filesystem()' call
        >>> # and the dataset initializer
        >>> credentials = {"key1": "secret1", "key2": "secret2"}
        >>>
        >>> data_set = PartitionedDataSet(
        >>>     path="s3://bucket-name/path/to/folder",
        >>>     dataset="CSVDataSet",
        >>>     credentials=credentials
        >>> )
        >>> loaded = data_set.load()
        >>> # assert isinstance(loaded, dict)
        >>>
        >>> combine_all = pd.DataFrame()
        >>>
        >>> for partition_id, partition_load_func in loaded.items():
        >>>     partition_data = partition_load_func()
        >>>     combine_all = pd.concat(
        >>>         [combine_all, partition_data], ignore_index=True, sort=True
        >>>     )
        >>>
        >>> new_data = pd.DataFrame({"new": [1, 2]})
        >>> # creates "s3://bucket-name/path/to/folder/new/partition.csv"
        >>> data_set.save({"new/partition.csv": new_data})

    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        path: str,
        dataset: Union[str, Type[AbstractDataSet], Dict[str, Any]],
        filepath_arg: str = "filepath",
        filename_suffix: str = "",
        credentials: Dict[str, Any] = None,
        load_args: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ):
        """Creates a new instance of ``PartitionedDataSet``.

        Args:
            path: Path to the folder containing partitioned data.
                If path starts with the protocol (e.g., ``s3://``) then the
                corresponding ``fsspec`` concrete filesystem implementation will
                be used. If protocol is not specified,
                ``fsspec.implementations.local.LocalFileSystem`` will be used.
                **Note:** Some concrete implementations are bundled with ``fsspec``,
                while others (like ``s3`` or ``gcs``) must be installed separately
                prior to usage of the ``PartitionedDataSet``.
            dataset: Underlying dataset definition. This is used to instantiate
                the dataset for each file located inside the ``path``.
                Accepted formats are:
                a) object of a class that inherits from ``AbstractDataSet``
                b) a string representing a fully qualified class name to such class
                c) a dictionary with ``type`` key pointing to a string from b),
                other keys are passed to the Dataset initializer.
                Credentials for the dataset can be explicitly specified in
                this configuration.
            filepath_arg: Underlying dataset initializer argument that will
                contain a path to each corresponding partition file.
                If unspecified, defaults to "filepath".
            filename_suffix: If specified, only partitions that end with this
                string will be processed.
            credentials: Protocol-specific options that will be passed to
                ``fsspec.filesystem`` and to the dataset initializer. If the
                dataset config contains an explicit credentials spec, that spec
                takes precedence and these credentials are not propagated to it.
            load_args: Keyword arguments to be passed into ``find()`` method of
                the filesystem implementation.
            fs_args: Extra arguments to pass into underlying filesystem class
                constructor (e.g. ``{"project": "my-project"}`` for
                ``GCSFileSystem``). They are also copied into the dataset config
                under ``fs_args`` unless the config already defines that key.

        Raises:
            DataSetError: If versioning is enabled for the underlying dataset.
        """
        # pylint: disable=import-outside-toplevel
        import warnings

        from fsspec.utils import infer_storage_options  # for performance reasons

        super().__init__()

        self._path = path
        self._filename_suffix = filename_suffix
        self._protocol = infer_storage_options(self._path)["protocol"]
        # Single slot holding the result of the last ``_list_partitions`` call;
        # cleared by ``_invalidate_caches``.
        self._partition_cache = Cache(maxsize=1)

        dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
        self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
        if VERSION_KEY in self._dataset_config:
            raise DataSetError(
                "`{}` does not support versioning of the underlying dataset. "
                "Please remove `{}` flag from the dataset definition.".format(
                    self.__class__.__name__, VERSIONED_FLAG_KEY
                )
            )

        # %-mapping template used when a top-level option is shadowed by an
        # explicit entry in the underlying dataset config.
        propagation_warning = (
            "Top-level %(keys)s will not propagate into the %(target)s since "
            "%(keys)s were explicitly defined in the %(target)s config."
        )

        if credentials:
            if CREDENTIALS_KEY in self._dataset_config:
                warnings.warn(
                    propagation_warning
                    % {"keys": CREDENTIALS_KEY, "target": "underlying dataset"}
                )
            else:
                self._dataset_config[CREDENTIALS_KEY] = deepcopy(credentials)

        self._credentials = deepcopy(credentials) or {}

        self._fs_args = deepcopy(fs_args) or {}
        if self._fs_args:
            if "fs_args" in self._dataset_config:
                warnings.warn(
                    propagation_warning
                    % {"keys": "filesystem arguments", "target": "underlying dataset"}
                )
            else:
                self._dataset_config["fs_args"] = deepcopy(self._fs_args)

        self._filepath_arg = filepath_arg
        if self._filepath_arg in self._dataset_config:
            warnings.warn(
                "`{}` key must not be specified in the dataset definition as it "
                "will be overwritten by partition path".format(self._filepath_arg)
            )

        self._load_args = deepcopy(load_args) or {}
        self._sep = self._filesystem.sep
        # since some filesystem implementations may implement a global cache
        self._invalidate_caches()

    @property
    def _filesystem(self):
        """``fsspec`` filesystem for ``path``, built from credentials and fs_args.

        S3 protocol aliases (``S3_PROTOCOLS``) are mapped to plain ``s3``.
        """
        # for performance reasons
        import fsspec  # pylint: disable=import-outside-toplevel

        protocol = "s3" if self._protocol in S3_PROTOCOLS else self._protocol
        return fsspec.filesystem(protocol, **self._credentials, **self._fs_args)

    @property
    def _normalized_path(self) -> str:
        """``path`` with any S3 alias scheme rewritten to ``s3``."""
        if self._protocol in S3_PROTOCOLS:
            return urlparse(self._path)._replace(scheme="s3").geturl()
        return self._path

    def _list_partitions(self) -> List[str]:
        """Return every file under ``path`` ending with ``filename_suffix``.

        The result is memoised in ``_partition_cache`` until
        ``_invalidate_caches`` is called.
        """
        cache_key = "partitions"
        if cache_key not in self._partition_cache:
            self._partition_cache[cache_key] = [
                path
                for path in self._filesystem.find(
                    self._normalized_path, **self._load_args
                )
                if path.endswith(self._filename_suffix)
            ]
        return self._partition_cache[cache_key]

    def _join_protocol(self, path: str) -> str:
        """Prefix ``path`` with ``<protocol>://`` when ``self._path`` carries one.

        The comparison uses the full ``<protocol>://`` prefix rather than the
        bare protocol name. Otherwise a local directory such as ``files/...``
        (protocol ``file``) would be mistaken for a URL, and a bucket/key that
        merely starts with ``s3`` would never be re-joined.
        """
        protocol_prefix = f"{self._protocol}://"
        if self._path.startswith(protocol_prefix) and not path.startswith(
            protocol_prefix
        ):
            return f"{protocol_prefix}{path}"
        return path

    def _partition_to_path(self, path: str):
        """Map a partition id to its full path under ``path`` (adds the suffix)."""
        dir_path = self._path.rstrip(self._sep)
        path = path.lstrip(self._sep)
        full_path = self._sep.join([dir_path, path]) + self._filename_suffix
        return full_path

    def _path_to_partition(self, path: str) -> str:
        """Inverse of ``_partition_to_path``: strip the folder and the suffix."""
        dir_path = self._filesystem._strip_protocol(self._normalized_path)
        path = path.split(dir_path, 1).pop().lstrip(self._sep)
        if self._filename_suffix and path.endswith(self._filename_suffix):
            path = path[: -len(self._filename_suffix)]
        return path

    def _load(self) -> Dict[str, Callable[[], Any]]:
        """Return ``{partition_id: dataset.load}`` for every partition found.

        Raises:
            DataSetError: If no partitions exist under ``path``.
        """
        partitions = {}

        for partition in self._list_partitions():
            kwargs = deepcopy(self._dataset_config)
            # join the protocol back since PySpark may rely on it
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)  # type: ignore
            partition_id = self._path_to_partition(partition)
            partitions[partition_id] = dataset.load

        if not partitions:
            raise DataSetError(f"No partitions found in `{self._path}`")

        return partitions

    def _save(self, data: Dict[str, Any]) -> None:
        """Save each ``{partition_id: data}`` entry through the underlying dataset."""
        for partition_id, partition_data in sorted(data.items()):
            kwargs = deepcopy(self._dataset_config)
            partition = self._partition_to_path(partition_id)
            # join the protocol back since tools like PySpark may rely on it
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)  # type: ignore
            dataset.save(partition_data)
        # the cached partition listing is stale once new files are written
        self._invalidate_caches()

    def _describe(self) -> Dict[str, Any]:
        # credentials are stripped so they never end up in logs/reprs
        clean_dataset_config = (
            {k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY}
            if isinstance(self._dataset_config, dict)
            else self._dataset_config
        )
        return dict(
            path=self._path,
            dataset_type=self._dataset_type.__name__,
            dataset_config=clean_dataset_config,
        )

    def _invalidate_caches(self):
        """Drop the memoised partition list and the filesystem's listing cache."""
        self._partition_cache.clear()
        self._filesystem.invalidate_cache(self._normalized_path)

    def _exists(self) -> bool:
        return bool(self._list_partitions())

    def _release(self) -> None:
        super()._release()
        self._invalidate_caches()
# Example No. 2
class AbstractVersionedDataSet(AbstractDataSet, abc.ABC):
    """
    ``AbstractVersionedDataSet`` is the base class for all versioned data set
    implementations. All data sets that implement versioning should extend this
    abstract class and implement the methods marked as abstract.

    Example:
    ::

        >>> from pathlib import Path, PurePosixPath
        >>> import pandas as pd
        >>> from kedro.io import AbstractVersionedDataSet
        >>>
        >>>
        >>> class MyOwnDataSet(AbstractVersionedDataSet):
        >>>     def __init__(self, filepath, version, param1, param2=True):
        >>>         super().__init__(PurePosixPath(filepath), version)
        >>>         self._param1 = param1
        >>>         self._param2 = param2
        >>>
        >>>     def _load(self) -> pd.DataFrame:
        >>>         load_path = self._get_load_path()
        >>>         return pd.read_csv(load_path)
        >>>
        >>>     def _save(self, df: pd.DataFrame) -> None:
        >>>         save_path = self._get_save_path()
        >>>         df.to_csv(str(save_path))
        >>>
        >>>     def _exists(self) -> bool:
        >>>         path = self._get_load_path()
        >>>         return Path(path.as_posix()).exists()
        >>>
        >>>     def _describe(self):
        >>>         return dict(version=self._version, param1=self._param1, param2=self._param2)

    Example catalog.yml specification:
    ::

        my_dataset:
            type: <path-to-my-own-dataset>.MyOwnDataSet
            filepath: data/01_raw/my_data.csv
            versioned: true
            param1: <param1-value> # param1 is a required argument
            # param2 will be True by default
    """

    def __init__(
        self,
        filepath: PurePosixPath,
        version: Optional[Version],
        exists_function: Optional[Callable[[str], bool]] = None,
        glob_function: Optional[Callable[[str], List[str]]] = None,
    ):
        """Creates a new instance of ``AbstractVersionedDataSet``.

        Args:
            filepath: Filepath in POSIX format to a file.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            exists_function: Function that is used for determining whether
                a path exists in a filesystem. Defaults to ``_local_exists``.
            glob_function: Function that is used for finding all paths
                in a filesystem, which match a given pattern. Defaults to
                ``iglob``.
        """
        self._filepath = filepath
        self._version = version
        self._exists_function = exists_function or _local_exists
        self._glob_function = glob_function or iglob
        # 1 entry for load version, 1 for save version
        self._version_cache = Cache(maxsize=2)

    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "load"))
    def _fetch_latest_load_version(self) -> str:
        """Return the newest existing version directory name under ``filepath``.

        Raises:
            VersionNotFoundError: If no versioned file exists yet.
        """
        # When load version is unpinned, fetch the most recent existing
        # version from the given path.
        pattern = str(self._get_versioned_path("*"))
        version_paths = sorted(self._glob_function(pattern), reverse=True)
        most_recent = next(
            (path for path in version_paths if self._exists_function(path)), None
        )

        if not most_recent:
            raise VersionNotFoundError(f"Did not find any versions for {self}")

        return PurePath(most_recent).parent.name

    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "save"))
    def _fetch_latest_save_version(self) -> str:  # pylint: disable=no-self-use
        """Generate and cache the current save version"""
        return generate_timestamp()

    def resolve_load_version(self) -> Optional[str]:
        """Compute the version the dataset should be loaded with."""
        if not self._version:
            return None
        if self._version.load:
            return self._version.load
        return self._fetch_latest_load_version()

    def _get_load_path(self) -> PurePosixPath:
        if not self._version:
            # When versioning is disabled, load from original filepath
            return self._filepath

        load_version = self.resolve_load_version()
        return self._get_versioned_path(load_version)  # type: ignore

    def resolve_save_version(self) -> Optional[str]:
        """Compute the version the dataset should be saved with."""
        if not self._version:
            return None
        if self._version.save:
            return self._version.save
        return self._fetch_latest_save_version()

    def _get_save_path(self) -> PurePosixPath:
        """Return the versioned save path.

        Raises:
            DataSetError: If that versioned path already exists.
        """
        if not self._version:
            # When versioning is disabled, return original filepath
            return self._filepath

        save_version = self.resolve_save_version()
        versioned_path = self._get_versioned_path(save_version)  # type: ignore

        if self._exists_function(str(versioned_path)):
            raise DataSetError(
                f"Save path `{versioned_path}` for {str(self)} must not exist if "
                f"versioning is enabled."
            )

        return versioned_path

    def _get_versioned_path(self, version: str) -> PurePosixPath:
        """Layout: ``<filepath>/<version>/<filename>``."""
        return self._filepath / version / self._filepath.name

    def load(self) -> Any:
        """Resolve the load version first, then delegate to the base ``load``."""
        self.resolve_load_version()  # Make sure last load version is set
        return super().load()

    def save(self, data: Any) -> None:
        """Save ``data`` under the resolved save version.

        Warns when the version that would be loaded afterwards differs from the
        version just saved.

        Raises:
            DataSetError: If a plain (non-versioned) file already occupies
                ``filepath``, or if the underlying save fails.
        """
        import warnings  # pylint: disable=import-outside-toplevel

        # Drop cached versions so an unpinned save gets a fresh timestamp on
        # every call. Reusing the cached one would make ``_get_save_path``
        # reject the already-written path on a second save. Clearing also
        # forces the load version to be re-resolved after this save.
        self._version_cache.clear()
        save_version = self.resolve_save_version()  # Make sure last save version is set
        try:
            super().save(data)
        except (FileNotFoundError, NotADirectoryError) as err:
            # FileNotFoundError raised in Win, NotADirectoryError raised in Unix
            _default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
            raise DataSetError(
                f"Cannot save versioned dataset `{self._filepath.name}` to "
                f"`{self._filepath.parent.as_posix()}` because a file with the same "
                f"name already exists in the directory. This is likely because "
                f"versioning was enabled on a dataset already saved previously. Either "
                f"remove `{self._filepath.name}` from the directory or manually "
                f"convert it into a versioned dataset by placing it in a versioned "
                f"directory (e.g. with default versioning format "
                f"`{self._filepath.as_posix()}/{_default_version}/{self._filepath.name}"
                f"`)."
            ) from err

        load_version = self.resolve_load_version()
        if load_version != save_version:
            warnings.warn(
                _CONSISTENCY_WARNING.format(save_version, load_version, str(self))
            )

    def exists(self) -> bool:
        """Checks whether a data set's output already exists by calling
        the provided _exists() method.

        Returns:
            Flag indicating whether the output already exists.

        Raises:
            DataSetError: when underlying exists method raises error.

        """
        self._logger.debug("Checking whether target of %s exists", str(self))
        try:
            return self._exists()
        except VersionNotFoundError:
            # no saved version yet means "does not exist", not a failure
            return False
        except Exception as exc:  # SKIP_IF_NO_SPARK
            message = (
                f"Failed during exists check for data set {str(self)}.\n{str(exc)}"
            )
            raise DataSetError(message) from exc

    def _release(self) -> None:
        super()._release()
        # forget resolved versions so the next load/save re-resolves them
        self._version_cache.clear()
# Example No. 3
 def clear(self):
     with self.__timer as time: