def add_feed_dict(self, feed_dict: Dict[str, Any], replace: bool = False) -> None:
    """Adds instances of ``MemoryDataSet``, containing the data provided
    through feed_dict.

    Args:
        feed_dict: A feed dict with data to be added in memory.
        replace: Specifies whether replacing an existing ``DataSet``
            with the same name is allowed.

    Example:
    ::

        >>> import pandas as pd
        >>>
        >>> df = pd.DataFrame({'col1': [1, 2],
        >>>                    'col2': [4, 5],
        >>>                    'col3': [5, 6]})
        >>>
        >>> io = DataCatalog()
        >>> io.add_feed_dict({
        >>>     'data': df
        >>> }, replace=True)
        >>>
        >>> assert io.load("data").equals(df)
    """
    for data_set_name in feed_dict:
        if isinstance(feed_dict[data_set_name], AbstractDataSet):
            data_set = feed_dict[data_set_name]
        else:
            data_set = MemoryDataSet(data=feed_dict[data_set_name])

        self.add(data_set_name, data_set, replace)
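# Usage sketch (illustrative, not part of the Kedro source): it exercises both
# branches of ``add_feed_dict`` above -- a plain Python object is wrapped in a
# ``MemoryDataSet``, while a value that is already an ``AbstractDataSet``
# instance is registered as-is. The dataset names and data are placeholders;
# this assumes ``DataCatalog`` and ``MemoryDataSet`` are importable from
# ``kedro.io`` as in Kedro releases that ship this code.

from kedro.io import DataCatalog, MemoryDataSet

io = DataCatalog()
io.add_feed_dict(
    {
        "parameters": {"learning_rate": 0.01},     # wrapped in a MemoryDataSet
        "numbers": MemoryDataSet(data=[1, 2, 3]),  # registered unchanged
    },
    replace=True,
)

assert io.load("parameters") == {"learning_rate": 0.01}
assert io.load("numbers") == [1, 2, 3]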
class CachedDataSet(AbstractDataSet):
    """``CachedDataSet`` is a dataset wrapper which caches in memory the data saved,
    so that the user avoids I/O operations with slow storage media.

    You can also specify a ``CachedDataSet`` in catalog.yml:
    ::

        >>> test_ds:
        >>>    type: CachedDataSet
        >>>    versioned: true
        >>>    dataset:
        >>>       type: pandas.CSVDataSet
        >>>       filepath: example.csv

    Please note that if your dataset is versioned, this should be indicated
    in the wrapper class as shown above.
    """

    def __init__(
        self,
        dataset: Union[AbstractDataSet, Dict],
        version: Version = None,
        copy_mode: str = None,
    ):
        """Creates a new instance of ``CachedDataSet`` pointing to the
        provided Python object.

        Args:
            dataset: A Kedro DataSet object or a dictionary to cache.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            copy_mode: The copy mode used to copy the data. Possible
                values are: "deepcopy", "copy" and "assign". If not
                provided, it is inferred based on the data type.

        Raises:
            ValueError: If the provided dataset is not a valid dict/YAML
                representation of a dataset or an actual dataset.
        """
        if isinstance(dataset, dict):
            self._dataset = self._from_config(dataset, version)
        elif isinstance(dataset, AbstractDataSet):
            self._dataset = dataset
        else:
            raise ValueError(
                "The argument type of `dataset` should be either a dict/YAML "
                "representation of the dataset, or the actual dataset object."
            )
        self._cache = MemoryDataSet(copy_mode=copy_mode)

    def _release(self) -> None:
        self._cache.release()
        self._dataset.release()

    @staticmethod
    def _from_config(config, version):
        if VERSIONED_FLAG_KEY in config:
            raise ValueError(
                "Cached datasets should specify that they are versioned in the "
                "`CachedDataSet`, not in the wrapped dataset."
            )
        if version:
            config[VERSIONED_FLAG_KEY] = True
            return AbstractDataSet.from_config(
                "_cached", config, version.load, version.save
            )
        return AbstractDataSet.from_config("_cached", config)

    def _describe(self) -> Dict[str, Any]:
        return {
            "dataset": self._dataset._describe(),  # pylint: disable=protected-access
            "cache": self._cache._describe(),  # pylint: disable=protected-access
        }

    def _load(self):
        data = self._cache.load() if self._cache.exists() else self._dataset.load()

        if not self._cache.exists():
            self._cache.save(data)

        return data

    def _save(self, data: Any) -> None:
        self._dataset.save(data)
        self._cache.save(data)

    def _exists(self) -> bool:
        return self._cache.exists() or self._dataset.exists()

    def __getstate__(self):
        # clearing the cache can be prevented by modifying
        # how parallel runner handles datasets (not trivial!)
        logging.getLogger(__name__).warning("%s: clearing cache to pickle.", str(self))
        self._cache.release()
        return self.__dict__
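# Usage sketch (illustrative, not part of the Kedro source): wrapping an
# existing dataset in ``CachedDataSet`` means the first ``load`` reads from the
# wrapped dataset and populates the in-memory cache, and subsequent loads are
# served from memory (see ``_load`` above). This assumes ``CachedDataSet`` and
# ``MemoryDataSet`` are importable from ``kedro.io``; the data values are
# placeholders.

from kedro.io import CachedDataSet, MemoryDataSet

wrapped = MemoryDataSet(data=[1, 2, 3])
cached = CachedDataSet(dataset=wrapped)

first = cached.load()   # read from the wrapped dataset, then saved to the cache
second = cached.load()  # served from the in-memory cache
assert first == second

# The wrapped dataset can also be given as a dict, mirroring the catalog.yml
# form shown in the class docstring (requires the referenced dataset type,
# here pandas.CSVDataSet, to be installed):
cached_csv = CachedDataSet(
    dataset={"type": "pandas.CSVDataSet", "filepath": "example.csv"}
)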