Example #1
class PartitionedDataSet(AbstractDataSet):
    # pylint: disable=too-many-instance-attributes,protected-access
    """``PartitionedDataSet`` loads and saves partitioned file-like data using the
    underlying dataset definition. For filesystem level operations it uses `fsspec`:
    https://github.com/intake/filesystem_spec.

    Example:
    ::

        >>> import pandas as pd
        >>> from kedro.io import PartitionedDataSet
        >>>
        >>> # these credentials will be passed to both 'fsspec.filesystem()' call
        >>> # and the dataset initializer
        >>> credentials = {"key1": "secret1", "key2": "secret2"}
        >>>
        >>> data_set = PartitionedDataSet(
        >>>     path="s3://bucket-name/path/to/folder",
        >>>     dataset="pandas.CSVDataSet",
        >>>     credentials=credentials
        >>> )
        >>> loaded = data_set.load()
        >>> # assert isinstance(loaded, dict)
        >>>
        >>> combine_all = pd.DataFrame()
        >>>
        >>> for partition_id, partition_load_func in loaded.items():
        >>>     partition_data = partition_load_func()
        >>>     combine_all = pd.concat(
        >>>         [combine_all, partition_data], ignore_index=True, sort=True
        >>>     )
        >>>
        >>> new_data = pd.DataFrame({"new": [1, 2]})
        >>> # creates "s3://bucket-name/path/to/folder/new/partition.csv"
        >>> data_set.save({"new/partition.csv": new_data})
        >>>
    """
    def __init__(  # pylint: disable=too-many-arguments
        self,
        path: str,
        dataset: Union[str, Type[AbstractDataSet], Dict[str, Any]],
        filepath_arg: str = "filepath",
        filename_suffix: str = "",
        credentials: Dict[str, Any] = None,
        load_args: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ):
        """Creates a new instance of ``PartitionedDataSet``.

        Args:
            path: Path to the folder containing partitioned data.
                If path starts with the protocol (e.g., ``s3://``) then the
                corresponding ``fsspec`` concrete filesystem implementation will
                be used. If protocol is not specified,
                ``fsspec.implementations.local.LocalFileSystem`` will be used.
                **Note:** Some concrete implementations are bundled with ``fsspec``,
                while others (like ``s3`` or ``gcs``) must be installed separately
                prior to usage of the ``PartitionedDataSet``.
            dataset: Underlying dataset definition. This is used to instantiate
                the dataset for each file located inside the ``path``.
                Accepted formats are:
                a) object of a class that inherits from ``AbstractDataSet``
                b) a string representing the fully qualified class name of such a class
                c) a dictionary with ``type`` key pointing to a string from b),
                other keys are passed to the Dataset initializer.
                Credentials for the dataset can be explicitly specified in
                this configuration.
            filepath_arg: Underlying dataset initializer argument that will
                contain a path to each corresponding partition file.
                If unspecified, defaults to "filepath".
            filename_suffix: If specified, only partitions that end with this
                string will be processed.
            credentials: Protocol-specific options that will be passed to
                ``fsspec.filesystem``
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.filesystem
                and the dataset initializer. If the dataset config contains
                explicit credentials spec, then such spec will take precedence.
                All possible credentials management scenarios are documented here:
                https://kedro.readthedocs.io/en/stable/04_user_guide/08_advanced_io.html#partitioned-dataset-credentials
            load_args: Keyword arguments to be passed into ``find()`` method of
                the filesystem implementation.
            fs_args: Extra arguments to pass into underlying filesystem class constructor
                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``)

        Raises:
            DataSetError: If versioning is enabled for the underlying dataset.
        """
        # pylint: disable=import-outside-toplevel
        from fsspec.utils import infer_storage_options  # for performance reasons

        super().__init__()

        self._path = path
        self._filename_suffix = filename_suffix
        self._protocol = infer_storage_options(self._path)["protocol"]
        self._partition_cache = Cache(maxsize=1)

        dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
        self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
        if VERSION_KEY in self._dataset_config:
            raise DataSetError(
                "`{}` does not support versioning of the underlying dataset. "
                "Please remove `{}` flag from the dataset definition.".format(
                    self.__class__.__name__, VERSIONED_FLAG_KEY
                )
            )

        if credentials:
            if CREDENTIALS_KEY in self._dataset_config:
                self._logger.warning(
                    KEY_PROPAGATION_WARNING,
                    {
                        "keys": CREDENTIALS_KEY,
                        "target": "underlying dataset"
                    },
                )
            else:
                self._dataset_config[CREDENTIALS_KEY] = deepcopy(credentials)

        self._credentials = deepcopy(credentials) or {}

        self._fs_args = deepcopy(fs_args) or {}
        if self._fs_args:
            if "fs_args" in self._dataset_config:
                self._logger.warning(
                    KEY_PROPAGATION_WARNING,
                    {
                        "keys": "filesystem arguments",
                        "target": "underlying dataset"
                    },
                )
            else:
                self._dataset_config["fs_args"] = deepcopy(self._fs_args)

        self._filepath_arg = filepath_arg
        if self._filepath_arg in self._dataset_config:
            warn(
                "`{}` key must not be specified in the dataset definition as it "
                "will be overwritten by partition path".format(self._filepath_arg)
            )

        self._load_args = deepcopy(load_args) or {}
        self._sep = self._filesystem.sep
        # since some filesystem implementations may implement a global cache
        self._invalidate_caches()

    @property
    def _filesystem(self):
        # for performance reasons
        import fsspec  # pylint: disable=import-outside-toplevel

        protocol = "s3" if self._protocol in S3_PROTOCOLS else self._protocol
        return fsspec.filesystem(protocol, **self._credentials,
                                 **self._fs_args)

    @property
    def _normalized_path(self) -> str:
        if self._protocol in S3_PROTOCOLS:
            return urlparse(self._path)._replace(scheme="s3").geturl()
        return self._path

    @cachedmethod(cache=operator.attrgetter("_partition_cache"))
    def _list_partitions(self) -> List[str]:
        return [
            path
            for path in self._filesystem.find(self._normalized_path, **self._load_args)
            if path.endswith(self._filename_suffix)
        ]

    def _join_protocol(self, path: str) -> str:
        if self._path.startswith(self._protocol) and not path.startswith(
            self._protocol
        ):
            return f"{self._protocol}://{path}"
        return path

    def _partition_to_path(self, path: str):
        dir_path = self._path.rstrip(self._sep)
        path = path.lstrip(self._sep)
        full_path = self._sep.join([dir_path, path]) + self._filename_suffix
        return full_path

    def _path_to_partition(self, path: str) -> str:
        dir_path = self._filesystem._strip_protocol(self._normalized_path)
        path = path.split(dir_path, 1).pop().lstrip(self._sep)
        if self._filename_suffix and path.endswith(self._filename_suffix):
            path = path[:-len(self._filename_suffix)]
        return path

    def _load(self) -> Dict[str, Callable[[], Any]]:
        partitions = {}

        for partition in self._list_partitions():
            kwargs = deepcopy(self._dataset_config)
            # join the protocol back since PySpark may rely on it
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)  # type: ignore
            partition_id = self._path_to_partition(partition)
            partitions[partition_id] = dataset.load

        if not partitions:
            raise DataSetError(f"No partitions found in `{self._path}`")

        return partitions

    def _save(self, data: Dict[str, Any]) -> None:
        for partition_id, partition_data in sorted(data.items()):
            kwargs = deepcopy(self._dataset_config)
            partition = self._partition_to_path(partition_id)
            # join the protocol back since tools like PySpark may rely on it
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)  # type: ignore
            dataset.save(partition_data)
        self._invalidate_caches()

    def _describe(self) -> Dict[str, Any]:
        clean_dataset_config = (
            {k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY}
            if isinstance(self._dataset_config, dict)
            else self._dataset_config
        )
        return dict(
            path=self._path,
            dataset_type=self._dataset_type.__name__,
            dataset_config=clean_dataset_config,
        )

    def _invalidate_caches(self):
        self._partition_cache.clear()
        self._filesystem.invalidate_cache(self._normalized_path)

    def _exists(self) -> bool:
        return bool(self._list_partitions())

    def _release(self) -> None:
        super()._release()
        self._invalidate_caches()
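
The ``Args`` section above lists three accepted ``dataset`` formats. Below is a minimal sketch of the dictionary form (format c), assuming the bundled ``pandas.CSVDataSet`` is installed and using placeholder paths and credentials; extra keys such as ``save_args`` are forwarded to every partition's dataset, and an explicit ``credentials`` key inside the dict would take precedence over the top-level ``credentials`` argument:

from kedro.io import PartitionedDataSet

data_set = PartitionedDataSet(
    path="s3://bucket-name/path/to/folder",  # placeholder bucket/prefix
    dataset={
        "type": "pandas.CSVDataSet",         # format c: "type" plus extra init kwargs
        "save_args": {"index": False},       # forwarded to each partition's dataset
    },
    filename_suffix=".csv",                  # only partitions ending in ".csv" are listed
    credentials={"key": "secret"},           # placeholder; passed to fsspec and the dataset
    load_args={"maxdepth": 1},               # forwarded to the filesystem's find()
)
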
Example #2
class AbstractVersionedDataSet(AbstractDataSet, abc.ABC):
    """
    ``AbstractVersionedDataSet`` is the base class for all versioned data set
    implementations. All data sets that implement versioning should extend this
    abstract class and implement the methods marked as abstract.

    Example:
    ::

        >>> from pathlib import Path, PurePosixPath
        >>> import pandas as pd
        >>> from kedro.io import AbstractVersionedDataSet
        >>>
        >>>
        >>> class MyOwnDataSet(AbstractVersionedDataSet):
        >>>     def __init__(self, filepath, version, param1, param2=True):
        >>>         super().__init__(PurePosixPath(filepath), version)
        >>>         self._param1 = param1
        >>>         self._param2 = param2
        >>>
        >>>     def _load(self) -> pd.DataFrame:
        >>>         load_path = self._get_load_path()
        >>>         return pd.read_csv(load_path)
        >>>
        >>>     def _save(self, df: pd.DataFrame) -> None:
        >>>         save_path = self._get_save_path()
        >>>         df.to_csv(str(save_path))
        >>>
        >>>     def _exists(self) -> bool:
        >>>         path = self._get_load_path()
        >>>         return Path(path.as_posix()).exists()
        >>>
        >>>     def _describe(self):
        >>>         return dict(version=self._version, param1=self._param1, param2=self._param2)

    Example catalog.yml specification:
    ::

        my_dataset:
            type: <path-to-my-own-dataset>.MyOwnDataSet
            filepath: data/01_raw/my_data.csv
            versioned: true
            param1: <param1-value> # param1 is a required argument
            # param2 will be True by default
    """

    def __init__(
        self,
        filepath: PurePosixPath,
        version: Optional[Version],
        exists_function: Callable[[str], bool] = None,
        glob_function: Callable[[str], List[str]] = None,
    ):
        """Creates a new instance of ``AbstractVersionedDataSet``.

        Args:
            filepath: Filepath in POSIX format to a file.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            exists_function: Function that is used for determining whether
                a path exists in a filesystem.
            glob_function: Function that is used for finding all paths
                in a filesystem, which match a given pattern.
        """
        self._filepath = filepath
        self._version = version
        self._exists_function = exists_function or _local_exists
        self._glob_function = glob_function or iglob
        # 1 entry for load version, 1 for save version
        self._version_cache = Cache(maxsize=2)

    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "load"))
    def _fetch_latest_load_version(self) -> str:
        # When load version is unpinned, fetch the most recent existing
        # version from the given path.
        pattern = str(self._get_versioned_path("*"))
        version_paths = sorted(self._glob_function(pattern), reverse=True)
        most_recent = next(
            (path for path in version_paths if self._exists_function(path)), None
        )

        if not most_recent:
            raise VersionNotFoundError(f"Did not find any versions for {self}")

        return PurePath(most_recent).parent.name

    # 'key' is set to prevent cache key overlapping for load and save:
    # https://cachetools.readthedocs.io/en/stable/#cachetools.cachedmethod
    @cachedmethod(cache=attrgetter("_version_cache"), key=partial(hashkey, "save"))
    def _fetch_latest_save_version(self) -> str:  # pylint: disable=no-self-use
        """Generate and cache the current save version"""
        return generate_timestamp()

    def resolve_load_version(self) -> Optional[str]:
        """Compute the version the dataset should be loaded with."""
        if not self._version:
            return None
        if self._version.load:
            return self._version.load
        return self._fetch_latest_load_version()

    def _get_load_path(self) -> PurePosixPath:
        if not self._version:
            # When versioning is disabled, load from original filepath
            return self._filepath

        load_version = self.resolve_load_version()
        return self._get_versioned_path(load_version)  # type: ignore

    def resolve_save_version(self) -> Optional[str]:
        """Compute the version the dataset should be saved with."""
        if not self._version:
            return None
        if self._version.save:
            return self._version.save
        return self._fetch_latest_save_version()

    def _get_save_path(self) -> PurePosixPath:
        if not self._version:
            # When versioning is disabled, return original filepath
            return self._filepath

        save_version = self.resolve_save_version()
        versioned_path = self._get_versioned_path(save_version)  # type: ignore

        if self._exists_function(str(versioned_path)):
            raise DataSetError(
                f"Save path `{versioned_path}` for {str(self)} must not exist if "
                f"versioning is enabled."
            )

        return versioned_path

    def _get_versioned_path(self, version: str) -> PurePosixPath:
        return self._filepath / version / self._filepath.name

    def load(self) -> Any:
        self.resolve_load_version()  # Make sure last load version is set
        return super().load()

    def save(self, data: Any) -> None:
        self._version_cache.clear()
        save_version = self.resolve_save_version()  # Make sure last save version is set
        try:
            super().save(data)
        except (FileNotFoundError, NotADirectoryError) as err:
            # FileNotFoundError raised in Win, NotADirectoryError raised in Unix
            _default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
            raise DataSetError(
                f"Cannot save versioned dataset `{self._filepath.name}` to "
                f"`{self._filepath.parent.as_posix()}` because a file with the same "
                f"name already exists in the directory. This is likely because "
                f"versioning was enabled on a dataset already saved previously. Either "
                f"remove `{self._filepath.name}` from the directory or manually "
                f"convert it into a versioned dataset by placing it in a versioned "
                f"directory (e.g. with default versioning format "
                f"`{self._filepath.as_posix()}/{_default_version}/{self._filepath.name}"
                f"`)."
            ) from err

        load_version = self.resolve_load_version()
        if load_version != save_version:
            warnings.warn(
                _CONSISTENCY_WARNING.format(save_version, load_version, str(self))
            )

    def exists(self) -> bool:
        """Checks whether a data set's output already exists by calling
        the provided _exists() method.

        Returns:
            Flag indicating whether the output already exists.

        Raises:
            DataSetError: when underlying exists method raises error.

        """
        self._logger.debug("Checking whether target of %s exists", str(self))
        try:
            return self._exists()
        except VersionNotFoundError:
            return False
        except Exception as exc:  # SKIP_IF_NO_SPARK
            message = (
                f"Failed during exists check for data set {str(self)}.\n{str(exc)}"
            )
            raise DataSetError(message) from exc

    def _release(self) -> None:
        super()._release()
        self._version_cache.clear()
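
A brief sketch of how ``Version`` drives the resolution logic above, assuming the bundled ``pandas.CSVDataSet`` and a placeholder filepath: with ``load=None`` the latest existing timestamp directory is resolved via ``_fetch_latest_load_version()``, with an explicit value the pinned version is returned unchanged, and ``save=None`` produces a fresh ``generate_timestamp()`` value:

from kedro.io import Version
from kedro.extras.datasets.pandas import CSVDataSet

# Unpinned: resolve_load_version() finds the newest existing version,
# resolve_save_version() generates a new timestamp on first use.
auto = CSVDataSet(filepath="data/my_data.csv", version=Version(None, None))

# Pinned load version: resolve_load_version() returns the string as given.
pinned = CSVDataSet(
    filepath="data/my_data.csv",
    version=Version(load="2021-01-01T00.00.00.000Z", save=None),
)
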
Example #3
    def clear(self):
        # Read the current time once from the cache's timer, evict every
        # entry that has expired by then, and clear whatever remains via
        # the base Cache implementation.
        with self.__timer as time:
            self.expire(time)
            Cache.clear(self)
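
This fragment has the shape of a TTL cache's ``clear()`` (as in ``cachetools.TTLCache``): expired entries are evicted first, then the base ``Cache.clear`` removes the rest. A small usage sketch, assuming ``cachetools`` is installed:

import time
from cachetools import TTLCache

cache = TTLCache(maxsize=128, ttl=0.1)  # entries live for 100 ms
cache["a"] = 1
time.sleep(0.2)                         # let the entry become stale
cache.clear()                           # expire() runs first, then the base clear()
assert len(cache) == 0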