Example #1

# Imports assumed from the original kedro.contrib.io.cached module (not shown in the snippet).
import logging
from typing import Any, Dict, Union
from warnings import warn

from kedro.io import AbstractDataSet, MemoryDataSet, Version
from kedro.io.core import VERSIONED_FLAG_KEY

class CachedDataSet(AbstractDataSet):
    """``CachedDataSet`` is a dataset wrapper which caches in memory the data saved,
    so that the user avoids io operations with slow storage media.

    You can also specify a ``CachedDataSet`` in catalog.yml:
    ::

        >>> test_ds:
        >>>    type: kedro.contrib.io.cached.CachedDataSet
        >>>    versioned: true
        >>>    dataset:
        >>>       type: CSVLocalDataSet
        >>>       filepath: example.csv

    Please note that if your dataset is versioned, this should be indicated in the wrapper
    class as shown above.
    """
    def __init__(self,
                 dataset: Union[AbstractDataSet, Dict],
                 version: Version = None):
        warn(
            "kedro.contrib.io.cached.CachedDataSet will be deprecated in future releases. "
            "Please refer to replacement dataset in kedro.io.",
            DeprecationWarning,
        )
        if isinstance(dataset, dict):
            self._dataset = self._from_config(dataset, version)
        elif isinstance(dataset, AbstractDataSet):
            self._dataset = dataset
        else:
            raise ValueError(
                "The argument type of `dataset` should be either a dict/YAML "
                "representation of the dataset, or the actual dataset object.")
        self._cache = MemoryDataSet()

    def _release(self) -> None:
        self._cache.release()
        self._dataset.release()

    @staticmethod
    def _from_config(config, version):
        if VERSIONED_FLAG_KEY in config:
            raise ValueError(
                "Cached datasets should specify that they are versioned in the "
                "`CachedDataSet`, not in the wrapped dataset.")
        if version:
            config[VERSIONED_FLAG_KEY] = True
            return AbstractDataSet.from_config("_cached", config, version.load,
                                               version.save)
        return AbstractDataSet.from_config("_cached", config)

    def _describe(self) -> Dict[str, Any]:
        return {
            "dataset": self._dataset._describe(),  # pylint: disable=protected-access
            "cache": self._cache._describe(),  # pylint: disable=protected-access
        }

    def _load(self):
        # Prefer the in-memory cache; fall back to the wrapped dataset.
        data = self._cache.load() if self._cache.exists() else self._dataset.load()

        if not self._cache.exists():
            self._cache.save(data)

        return data

    def _save(self, data: Any) -> None:
        self._dataset.save(data)
        self._cache.save(data)

    def _exists(self) -> bool:
        return self._cache.exists() or self._dataset.exists()

    def __getstate__(self):
        # clearing the cache can be prevented by modifying
        # how parallel runner handles datasets (not trivial!)
        logging.getLogger(__name__).warning("%s: clearing cache to pickle.",
                                            str(self))
        self._cache.release()
        return self.__dict__
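
A minimal usage sketch, not part of the original example: it assumes the kedro 0.15-era ``CSVLocalDataSet`` and ``pandas`` are available, and shows that the second load is served from the in-memory cache rather than from disk.

import pandas as pd

# Hypothetical illustration: wrap a CSV dataset with the CachedDataSet defined above.
cached = CachedDataSet(dataset={"type": "CSVLocalDataSet", "filepath": "example.csv"})

df = pd.DataFrame({"a": [1, 2, 3]})
cached.save(df)             # writes example.csv and populates the MemoryDataSet cache
df_again = cached.load()    # returned from memory; example.csv is not read again
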
Example #2

# Imports assumed from the original module context (not shown in the snippet).
# `dataset_dicts`, `mlflow_log_params`, `mlflow_log_metrics` and
# `mlflow_log_artifacts` are module-level helpers defined elsewhere in that module.
import tempfile
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, Union

from kedro.io import AbstractDataSet, MemoryDataSet
from kedro.extras.datasets.pickle import PickleDataSet

class MLflowDataSet(AbstractDataSet):
    """``MLflowDataSet`` saves data to, and loads data from MLflow.
    You can also specify a ``MLflowDataSet`` in catalog.yml:
    ::
        >>> test_ds:
        >>>    type: MLflowDataSet
        >>>    dataset: pkl
    """
    def __init__(
        self,
        dataset: Union[AbstractDataSet, Dict, str] = None,
        filepath: str = None,
        dataset_name: str = None,
        saving_tracking_uri: str = None,
        saving_experiment_name: str = None,
        saving_run_id: str = None,
        loading_tracking_uri: str = None,
        loading_run_id: str = None,
        caching: bool = True,
        copy_mode: str = None,
        file_caching: bool = True,
    ):
        """
        dataset: A Kedro DataSet object or a dictionary used to save/load.
            If set to either {"json", "csv", "xls", "parquet", "png", "jpg", "jpeg", "img",
            "pkl", "txt", "yml", "yaml"}, dataset instance will be created accordingly with
            filepath arg.
            If set to "p", the value will be saved/loaded as a parameter (string).
            If set to "m", the value will be saved/loaded as a metric (numeric).
            If None (default), MLflow will not be used.
        filepath: File path, usually in the local file system, to save to and load from.
            Used only if the dataset arg is a string.
            If None (default), `<temp directory>/<dataset_name arg>.<dataset arg>` is used.
        dataset_name: Used only if the dataset arg is a string and the filepath arg is None.
            If None (default), the Python object ID is used, but it is recommended to
            override it with a Kedro hook.
        saving_tracking_uri: MLflow Tracking URI to save to.
            If None (default), MLFLOW_TRACKING_URI environment variable is used.
        saving_experiment_name: MLflow experiment name to save to.
            If None (default), a new experiment will not be created or started.
            Ignored if saving_run_id is set.
        saving_run_id: An existing MLflow experiment run ID to save to.
            If None (default), no existing experiment run will be resumed.
        loading_tracking_uri: MLflow Tracking URI to load from.
            If None (default), MLFLOW_TRACKING_URI environment variable is used.
        loading_run_id: MLflow experiment run ID to load from.
            If None (default), the currently active run ID will be used if available.
        caching: Enable caching if the parallel runner is not used. True by default.
        copy_mode: The copy mode used to copy the data. Possible
            values are: "deepcopy", "copy" and "assign". If not
            provided, it is inferred based on the data type.
            Ignored if caching arg is False.
        file_caching: Attempt to use the file at filepath when loading if no cache is found
            in memory. True by default.
        """
        self.dataset = dataset or MemoryDataSet()
        self.filepath = filepath
        self.dataset_name = dataset_name
        self.saving_tracking_uri = saving_tracking_uri
        self.saving_experiment_name = saving_experiment_name
        self.saving_run_id = saving_run_id
        self.loading_tracking_uri = loading_tracking_uri
        self.loading_run_id = loading_run_id
        self.caching = caching
        self.file_caching = file_caching
        self.copy_mode = copy_mode
        self._dataset_name = str(id(self))

        if isinstance(dataset, str):
            if (dataset not in {"p", "m"}) and (dataset not in dataset_dicts):
                raise ValueError(
                    "`dataset`: {} not supported. Specify 'p', 'm', or one of {}.".format(
                        dataset, list(dataset_dicts.keys())))
        self._ready = False
        self._running_parallel = None
        self._cache = None

    def _init_dataset(self):
        # Lazily resolve the wrapped dataset and the in-memory cache on first use.
        if not getattr(self, "_ready", None):
            self._ready = True
            self.dataset_name = self.dataset_name or self._dataset_name
            _dataset = self.dataset
            if isinstance(self.dataset, str):
                dataset_dict = dataset_dicts.get(
                    self.dataset, {"type": "pickle.PickleDataSet"})
                dataset_dict["filepath"] = self.filepath = (
                    self.filepath or tempfile.gettempdir() + "/" +
                    self.dataset_name + "." + self.dataset)
                _dataset = dataset_dict

            if isinstance(_dataset, dict):
                self._dataset = AbstractDataSet.from_config(
                    self._dataset_name, _dataset)
            elif isinstance(_dataset, AbstractDataSet):
                self._dataset = _dataset
            else:
                raise ValueError(
                    "The argument type of `dataset` should be either a dict/YAML "
                    "representation of the dataset, or the actual dataset object."
                )

            _filepath = getattr(self._dataset, "_filepath", None)
            if _filepath:
                self.filepath = str(_filepath)

            if self.caching and (not self._running_parallel):
                self._cache = MemoryDataSet(copy_mode=self.copy_mode)

    def _release(self) -> None:
        self._init_dataset()
        self._dataset.release()
        if self._cache:
            self._cache.release()

    def _describe(self) -> Dict[str, Any]:
        return {
            "dataset": (
                self._dataset._describe()  # pylint: disable=protected-access
                if getattr(self, "_ready", None)
                else self.dataset
            ),
            "filepath": self.filepath,
            "saving_tracking_uri": self.saving_tracking_uri,
            "saving_experiment_name": self.saving_experiment_name,
            "saving_run_id": self.saving_run_id,
            "loading_tracking_uri": self.loading_tracking_uri,
            "loading_run_id": self.loading_run_id,
        }

    def _load(self):
        self._init_dataset()

        # 1. Serve from the in-memory cache if it has been populated.
        if self._cache and self._cache.exists():
            return self._cache.load()

        # 2. Fall back to the local file if file caching is enabled.
        if self.file_caching and self._dataset.exists():
            return self._dataset.load()

        # 3. Otherwise fetch the value from MLflow.
        import mlflow

        client = mlflow.tracking.MlflowClient(
            tracking_uri=self.loading_tracking_uri)

        self.loading_run_id = (
            self.loading_run_id or mlflow.active_run().info.run_id
        )

        if self.dataset in {"p"}:
            run = client.get_run(self.loading_run_id)
            value = run.data.params.get(self.dataset_name, None)
            if value is None:
                raise KeyError("param '{}' not found in run_id '{}'.".format(
                    self.dataset_name, self.loading_run_id))

            PickleDataSet(filepath=self.filepath).save(value)

        elif self.dataset in {"m"}:
            run = client.get_run(self.loading_run_id)
            value = run.data.metrics.get(self.dataset_name, None)
            if value is None:
                raise KeyError("metric '{}' not found in run_id '{}'.".format(
                    self.dataset_name, self.loading_run_id))
            PickleDataSet(filepath=self.filepath).save(value)

        else:
            # Download the artifact to a temp directory, then move it to the expected filepath.
            p = Path(self.filepath)

            dst_path = tempfile.gettempdir()
            downloaded_path = client.download_artifacts(
                run_id=self.loading_run_id,
                path=p.name,
                dst_path=dst_path,
            )
            if Path(downloaded_path) != p:
                Path(downloaded_path).rename(p)

        return self._dataset.load()

    def _save(self, data: Any) -> None:
        self._init_dataset()

        self._dataset.save(data)

        if self._cache:
            self._cache.save(data)

        if find_spec("mlflow"):
            import mlflow

            if self.saving_tracking_uri:
                mlflow.set_tracking_uri(self.saving_tracking_uri)

            if self.saving_run_id:
                mlflow.start_run(run_id=self.saving_run_id)

            elif self.saving_experiment_name:
                experiment_id = mlflow.get_experiment_by_name(
                    self.saving_experiment_name).experiment_id
                # saving_run_id is None in this branch, so a new run is started
                # in the named experiment.
                mlflow.start_run(run_id=self.saving_run_id,
                                 experiment_id=experiment_id)

            if self.dataset in {"p"}:
                mlflow_log_params({self.dataset_name: data})
            elif self.dataset in {"m"}:
                mlflow_log_metrics({self.dataset_name: data})
            else:
                mlflow_log_artifacts([self.filepath])

            if self.saving_run_id or self.saving_experiment_name:
                mlflow.end_run()

    def _exists(self) -> bool:
        self._init_dataset()
        if self._cache:
            return self._cache.exists()
        else:
            return False

    def __getstate__(self):
        return self.__dict__
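
A hedged usage sketch, not from the original source: it constructs ``MLflowDataSet`` instances directly and assumes the module-level helpers (``dataset_dicts`` and the ``mlflow_log_*`` functions) are importable from the surrounding module, and, when MLflow logging is wanted, that the ``mlflow`` package is installed and a run is active.

# Hypothetical illustration of the three `dataset` modes described in the docstring.
score_ds = MLflowDataSet(dataset="m", dataset_name="score")    # numeric -> MLflow metric
note_ds = MLflowDataSet(dataset="p", dataset_name="note")      # string  -> MLflow parameter
model_ds = MLflowDataSet(dataset="pkl", dataset_name="model")  # pickle file -> MLflow artifact

score_ds.save(0.95)                    # written to a temp pickle file, cached, and logged
note_ds.save("baseline run")
model_ds.save({"weights": [1, 2, 3]})

print(score_ds.load())                 # served from the in-memory cache when caching=True
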