def test_save_modify_original_data(spark_data_frame):
    """Check that the data set object is not updated when the original
    SparkDataFrame is changed."""
    memory_data_set = MemoryDataSet()
    memory_data_set.save(spark_data_frame)
    spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, "new value")
    assert not _check_equals(memory_data_set.load(), spark_data_frame)
class CachedDataSet(AbstractDataSet):
    """``CachedDataSet`` is a dataset wrapper which caches in memory the data saved,
    so that the user avoids io operations with slow storage media.

    You can also specify a ``CachedDataSet`` in catalog.yml:
    ::

        >>> test_ds:
        >>>    type: kedro.contrib.io.cached.CachedDataSet
        >>>    versioned: true
        >>>    dataset:
        >>>       type: CSVLocalDataSet
        >>>       filepath: example.csv

    Please note that if your dataset is versioned, this should be indicated in the
    wrapper class as shown above.
    """

    def __init__(self, dataset: Union[AbstractDataSet, Dict], version: Version = None):
        warn(
            "kedro.contrib.io.cached.CachedDataSet will be deprecated in future releases. "
            "Please refer to replacement dataset in kedro.io.",
            DeprecationWarning,
        )
        if isinstance(dataset, Dict):
            self._dataset = self._from_config(dataset, version)
        elif isinstance(dataset, AbstractDataSet):
            self._dataset = dataset
        else:
            raise ValueError(
                "The argument type of `dataset` should be either a dict/YAML "
                "representation of the dataset, or the actual dataset object."
            )
        self._cache = MemoryDataSet()

    def _release(self) -> None:
        self._cache.release()
        self._dataset.release()

    @staticmethod
    def _from_config(config, version):
        if VERSIONED_FLAG_KEY in config:
            raise ValueError(
                "Cached datasets should specify that they are versioned in the "
                "`CachedDataSet`, not in the wrapped dataset."
            )
        if version:
            config[VERSIONED_FLAG_KEY] = True
            return AbstractDataSet.from_config(
                "_cached", config, version.load, version.save
            )
        return AbstractDataSet.from_config("_cached", config)

    def _describe(self) -> Dict[str, Any]:
        return {
            "dataset": self._dataset._describe(),  # pylint: disable=protected-access
            "cache": self._cache._describe(),  # pylint: disable=protected-access
        }

    def _load(self):
        data = self._cache.load() if self._cache.exists() else self._dataset.load()

        if not self._cache.exists():
            self._cache.save(data)

        return data

    def _save(self, data: Any) -> None:
        self._dataset.save(data)
        self._cache.save(data)

    def _exists(self) -> bool:
        return self._cache.exists() or self._dataset.exists()

    def __getstate__(self):
        # clearing the cache can be prevented by modifying
        # how parallel runner handles datasets (not trivial!)
        logging.getLogger(__name__).warning("%s: clearing cache to pickle.", str(self))
        self._cache.release()
        return self.__dict__
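
# Minimal usage sketch (an illustrative addition, not part of the original source):
# wrapping an existing dataset in ``CachedDataSet`` programmatically instead of via
# catalog.yml. ``CSVLocalDataSet``, the file path, and the sample data frame are
# assumptions; any ``AbstractDataSet`` implementation can be wrapped the same way.
def _cached_dataset_usage_example():
    import pandas as pd
    from kedro.io import CSVLocalDataSet  # Kedro 0.15.x-era import path (assumed)

    cached_ds = CachedDataSet(CSVLocalDataSet(filepath="example.csv"))
    cached_ds.save(pd.DataFrame({"a": [1, 2]}))  # writes the file and fills the cache
    return cached_ds.load()  # second access is served from the in-memory cache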
class MLflowDataSet(AbstractDataSet):
    """``MLflowDataSet`` saves data to, and loads data from MLflow.

    You can also specify a ``MLflowDataSet`` in catalog.yml:
    ::

        >>> test_ds:
        >>>    type: MLflowDataSet
        >>>    dataset: pkl
    """

    def __init__(
        self,
        dataset: Union[AbstractDataSet, Dict, str] = None,
        filepath: str = None,
        dataset_name: str = None,
        saving_tracking_uri: str = None,
        saving_experiment_name: str = None,
        saving_run_id: str = None,
        loading_tracking_uri: str = None,
        loading_run_id: str = None,
        caching: bool = True,
        copy_mode: str = None,
        file_caching: bool = True,
    ):
        """
        Args:
            dataset: A Kedro DataSet object or a dictionary used to save/load.
                If set to either {"json", "csv", "xls", "parquet", "png", "jpg",
                "jpeg", "img", "pkl", "txt", "yml", "yaml"}, dataset instance will
                be created accordingly with filepath arg.
                If set to "p", the value will be saved/loaded as a parameter (string).
                If set to "m", the value will be saved/loaded as a metric (numeric).
                If None (default), MLflow will not be used.
            filepath: File path, usually in local file system, to save to and load
                from. Used only if the dataset arg is a string.
                If None (default), `<temp directory>/<dataset_name arg>.<dataset arg>`
                is used.
            dataset_name: Used only if the dataset arg is a string and filepath arg
                is None. If None (default), Python object ID is used, but recommended
                to overwrite by a Kedro hook.
            saving_tracking_uri: MLflow Tracking URI to save to.
                If None (default), MLFLOW_TRACKING_URI environment variable is used.
            saving_experiment_name: MLflow experiment name to save to.
                If None (default), new experiment will not be created or started.
                Ignored if saving_run_id is set.
            saving_run_id: An existing MLflow experiment run ID to save to.
                If None (default), no existing experiment run will be resumed.
            loading_tracking_uri: MLflow Tracking URI to load from.
                If None (default), MLFLOW_TRACKING_URI environment variable is used.
            loading_run_id: MLflow experiment run ID to load from.
                If None (default), current active run ID will be used if available.
            caching: Enable caching if parallel runner is not used. True in default.
            copy_mode: The copy mode used to copy the data. Possible values are:
                "deepcopy", "copy" and "assign". If not provided, it is inferred
                based on the data type. Ignored if caching arg is False.
            file_caching: Attempt to use the file at filepath when loading if no
                cache found in memory. True in default.
        """
        self.dataset = dataset or MemoryDataSet()
        self.filepath = filepath
        self.dataset_name = dataset_name
        self.saving_tracking_uri = saving_tracking_uri
        self.saving_experiment_name = saving_experiment_name
        self.saving_run_id = saving_run_id
        self.loading_tracking_uri = loading_tracking_uri
        self.loading_run_id = loading_run_id
        self.caching = caching
        self.file_caching = file_caching
        self.copy_mode = copy_mode

        self._dataset_name = str(id(self))

        if isinstance(dataset, str):
            if (dataset not in {"p", "m"}) and (dataset not in dataset_dicts):
                raise ValueError(
                    "`dataset`: {} not supported. Specify one of {}.".format(
                        dataset, list(dataset_dicts.keys())
                    )
                )

        self._ready = False
        self._running_parallel = None
        self._cache = None

    def _init_dataset(self):
        if not getattr(self, "_ready", None):
            self._ready = True
            self.dataset_name = self.dataset_name or self._dataset_name

            _dataset = self.dataset
            if isinstance(self.dataset, str):
                dataset_dict = dataset_dicts.get(
                    self.dataset, {"type": "pickle.PickleDataSet"}
                )
                dataset_dict["filepath"] = self.filepath = (
                    self.filepath
                    or tempfile.gettempdir()
                    + "/"
                    + self.dataset_name
                    + "."
                    + self.dataset
                )
                _dataset = dataset_dict

            if isinstance(_dataset, dict):
                self._dataset = AbstractDataSet.from_config(
                    self._dataset_name, _dataset
                )
            elif isinstance(_dataset, AbstractDataSet):
                self._dataset = _dataset
            else:
                raise ValueError(
                    "The argument type of `dataset` should be either a dict/YAML "
                    "representation of the dataset, or the actual dataset object."
                )

            _filepath = getattr(self._dataset, "_filepath", None)
            if _filepath:
                self.filepath = str(_filepath)

            if self.caching and (not self._running_parallel):
                self._cache = MemoryDataSet(copy_mode=self.copy_mode)

    def _release(self) -> None:
        self._init_dataset()
        self._dataset.release()
        if self._cache:
            self._cache.release()

    def _describe(self) -> Dict[str, Any]:
        return {
            "dataset": self._dataset._describe()  # pylint: disable=protected-access
            if getattr(self, "_ready", None)
            else self.dataset,
            "filepath": self.filepath,
            "saving_tracking_uri": self.saving_tracking_uri,
            "saving_experiment_name": self.saving_experiment_name,
            "saving_run_id": self.saving_run_id,
            "loading_tracking_uri": self.loading_tracking_uri,
            "loading_run_id": self.loading_run_id,
        }

    def _load(self):
        self._init_dataset()

        if self._cache and self._cache.exists():
            return self._cache.load()

        if self.file_caching and self._dataset.exists():
            return self._dataset.load()

        import mlflow

        client = mlflow.tracking.MlflowClient(tracking_uri=self.loading_tracking_uri)
        self.loading_run_id = self.loading_run_id or mlflow.active_run().info.run_id

        if self.dataset in {"p"}:
            run = client.get_run(self.loading_run_id)
            value = run.data.params.get(self.dataset_name, None)
            if value is None:
                raise KeyError(
                    "param '{}' not found in run_id '{}'.".format(
                        self.dataset_name, self.loading_run_id
                    )
                )
            PickleDataSet(filepath=self.filepath).save(value)

        elif self.dataset in {"m"}:
            run = client.get_run(self.loading_run_id)
            value = run.data.metrics.get(self.dataset_name, None)
            if value is None:
                raise KeyError(
                    "metric '{}' not found in run_id '{}'.".format(
                        self.dataset_name, self.loading_run_id
                    )
                )
            PickleDataSet(filepath=self.filepath).save(value)

        else:
            p = Path(self.filepath)
            dst_path = tempfile.gettempdir()
            downloaded_path = client.download_artifacts(
                run_id=self.loading_run_id,
                path=p.name,
                dst_path=dst_path,
            )
            if Path(downloaded_path) != p:
                Path(downloaded_path).rename(p)

        return self._dataset.load()

    def _save(self, data: Any) -> None:
        self._init_dataset()

        self._dataset.save(data)
        if self._cache:
            self._cache.save(data)

        if find_spec("mlflow"):
            import mlflow

            if self.saving_tracking_uri:
                mlflow.set_tracking_uri(self.saving_tracking_uri)

            if self.saving_run_id:
                mlflow.start_run(run_id=self.saving_run_id)
            elif self.saving_experiment_name:
                experiment_id = mlflow.get_experiment_by_name(
                    self.saving_experiment_name
                ).experiment_id
                mlflow.start_run(
                    run_id=self.saving_run_id, experiment_id=experiment_id
                )

            if self.dataset in {"p"}:
                mlflow_log_params({self.dataset_name: data})
            elif self.dataset in {"m"}:
                mlflow_log_metrics({self.dataset_name: data})
            else:
                mlflow_log_artifacts([self.filepath])

            if self.saving_run_id or self.saving_experiment_name:
                mlflow.end_run()

    def _exists(self) -> bool:
        self._init_dataset()
        if self._cache:
            return self._cache.exists()
        else:
            return False

    def __getstate__(self):
        return self.__dict__
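
# Minimal usage sketch (an illustrative addition, not part of the original source):
# logging a metric and a pickled artifact through ``MLflowDataSet``. The dataset
# names, file path, and values are placeholders; in practice these entries usually
# come from catalog.yml, and MLflow logging only happens when mlflow is installed
# and a run is active (or saving_run_id / saving_experiment_name is provided).
def _mlflow_dataset_usage_example():
    # Log a numeric value as an MLflow metric of the active run.
    score_ds = MLflowDataSet(dataset="m", dataset_name="f1_score")
    score_ds.save(0.93)

    # Pickle an object, keep it cached locally, and log the file as an MLflow artifact.
    model_ds = MLflowDataSet(dataset="pkl", filepath="model.pkl")
    model_ds.save({"weights": [0.1, 0.2]})
    return model_ds.load()  # served from the in-memory cache or the local file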