Example #1
    def __init__(
        self,
        dataset: Union[AbstractDataSet, Dict],
        version: Version = None,
        copy_mode: str = None,
    ):
        """Creates a new instance of ``CachedDataSet`` pointing to the
        provided Python object.

        Args:
            dataset: A Kedro DataSet object or a dictionary to cache.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, the save version will be autogenerated.
            copy_mode: The copy mode used to copy the data. Possible
                values are: "deepcopy", "copy" and "assign". If not
                provided, it is inferred based on the data type.

        Raises:
            ValueError: If the provided dataset is not a valid dict/YAML
                representation of a dataset or an actual dataset.
        """
        if isinstance(dataset, dict):
            self._dataset = self._from_config(dataset, version)
        elif isinstance(dataset, AbstractDataSet):
            self._dataset = dataset
        else:
            raise ValueError(
                "The argument type of `dataset` should be either a dict/YAML "
                "representation of the dataset, or the actual dataset object.")
        self._cache = MemoryDataSet(copy_mode=copy_mode)
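
A minimal usage sketch of the two accepted forms of the ``dataset`` argument
(assuming ``CachedDataSet`` and ``pandas.CSVDataSet`` are importable as below;
the CSV path is hypothetical):

from kedro.io import CachedDataSet
from kedro.extras.datasets.pandas import CSVDataSet

# 1) Wrap an already-instantiated dataset object.
cached_from_object = CachedDataSet(dataset=CSVDataSet(filepath="example.csv"))

# 2) Pass a dict/YAML-style config; _from_config builds the wrapped dataset.
cached_from_config = CachedDataSet(
    dataset={"type": "pandas.CSVDataSet", "filepath": "example.csv"}
)

# Any other argument type raises the ValueError shown above.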
Example #2
    def add_feed_dict(self,
                      feed_dict: Dict[str, Any],
                      replace: bool = False) -> None:
        """Adds instances of ``MemoryDataSet``, containing the data provided
        through feed_dict.

        Args:
            feed_dict: A feed dict with data to be added in memory.
            replace: Specifies whether replacing an existing ``DataSet``
                with the same name is allowed.

        Example:
        ::

            >>> import pandas as pd
            >>>
            >>> df = pd.DataFrame({'col1': [1, 2],
            >>>                    'col2': [4, 5],
            >>>                    'col3': [5, 6]})
            >>>
            >>> io = DataCatalog()
            >>> io.add_feed_dict({
            >>>     'data': df
            >>> }, replace=True)
            >>>
            >>> assert io.load("data").equals(df)
        """
        for data_set_name in feed_dict:
            if isinstance(feed_dict[data_set_name], AbstractDataSet):
                data_set = feed_dict[data_set_name]
            else:
                data_set = MemoryDataSet(data=feed_dict[data_set_name])

            self.add(data_set_name, data_set, replace)
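
A short sketch of the branch above (assuming ``DataCatalog`` and
``MemoryDataSet`` are importable from ``kedro.io``): values that are already
``AbstractDataSet`` instances are registered as-is, while plain Python objects
are wrapped in a ``MemoryDataSet``.

from kedro.io import DataCatalog, MemoryDataSet

catalog = DataCatalog()
catalog.add_feed_dict({
    "params": {"alpha": 0.1},                  # wrapped in a MemoryDataSet
    "scores": MemoryDataSet(data=[1, 2, 3]),   # added unchanged
})

assert catalog.load("params") == {"alpha": 0.1}
assert catalog.load("scores") == [1, 2, 3]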
Example #3
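The class below relies on the following imports (a sketch assuming the
top-level ``kedro.io`` re-exports of this Kedro release line;
``VERSIONED_FLAG_KEY`` lives in ``kedro.io.core``):

import logging
from typing import Any, Dict, Union

from kedro.io import AbstractDataSet, MemoryDataSet, Version
from kedro.io.core import VERSIONED_FLAG_KEY
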
class CachedDataSet(AbstractDataSet):
    """``CachedDataSet`` is a dataset wrapper which caches in memory the data saved,
    so that the user avoids I/O operations with slow storage media.

    You can also specify a ``CachedDataSet`` in catalog.yml:
    ::

        >>> test_ds:
        >>>    type: CachedDataSet
        >>>    versioned: true
        >>>    dataset:
        >>>       type: pandas.CSVDataSet
        >>>       filepath: example.csv

    Please note that if your dataset is versioned, this should be indicated in the wrapper
    class as shown above.
    """
    def __init__(
        self,
        dataset: Union[AbstractDataSet, Dict],
        version: Version = None,
        copy_mode: str = None,
    ):
        """Creates a new instance of ``CachedDataSet`` pointing to the
        provided Python object.

        Args:
            dataset: A Kedro DataSet object or a dictionary to cache.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, the save version will be autogenerated.
            copy_mode: The copy mode used to copy the data. Possible
                values are: "deepcopy", "copy" and "assign". If not
                provided, it is inferred based on the data type.

        Raises:
            ValueError: If the provided dataset is not a valid dict/YAML
                representation of a dataset or an actual dataset.
        """
        if isinstance(dataset, dict):
            self._dataset = self._from_config(dataset, version)
        elif isinstance(dataset, AbstractDataSet):
            self._dataset = dataset
        else:
            raise ValueError(
                "The argument type of `dataset` should be either a dict/YAML "
                "representation of the dataset, or the actual dataset object.")
        self._cache = MemoryDataSet(copy_mode=copy_mode)

    def _release(self) -> None:
        self._cache.release()
        self._dataset.release()

    @staticmethod
    def _from_config(config, version):
        if VERSIONED_FLAG_KEY in config:
            raise ValueError(
                "Cached datasets should specify that they are versioned in the "
                "`CachedDataSet`, not in the wrapped dataset.")
        if version:
            config[VERSIONED_FLAG_KEY] = True
            return AbstractDataSet.from_config("_cached", config, version.load,
                                               version.save)
        return AbstractDataSet.from_config("_cached", config)

    def _describe(self) -> Dict[str, Any]:
        return {
            "dataset": self._dataset._describe(),  # pylint: disable=protected-access
            "cache": self._cache._describe(),  # pylint: disable=protected-access
        }

    def _load(self):
        if self._cache.exists():
            data = self._cache.load()
        else:
            data = self._dataset.load()
            self._cache.save(data)

        return data

    def _save(self, data: Any) -> None:
        self._dataset.save(data)
        self._cache.save(data)

    def _exists(self) -> bool:
        return self._cache.exists() or self._dataset.exists()

    def __getstate__(self):
        # clearing the cache can be prevented by modifying
        # how parallel runner handles datasets (not trivial!)
        logging.getLogger(__name__).warning("%s: clearing cache to pickle.",
                                            str(self))
        self._cache.release()
        return self.__dict__
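
A minimal end-to-end sketch of the caching behaviour (assuming
``pandas.CSVDataSet`` from ``kedro.extras.datasets``; ``example.csv`` is a
hypothetical file): the first ``load()`` reads from the wrapped dataset and
fills the in-memory cache, so the second call is served from memory.

from kedro.io import CachedDataSet
from kedro.extras.datasets.pandas import CSVDataSet

cached = CachedDataSet(dataset=CSVDataSet(filepath="example.csv"))

first = cached.load()    # reads example.csv and saves the result to the cache
second = cached.load()   # served from the in-memory MemoryDataSet, no file I/O

cached.release()         # clears both the cache and the wrapped dataset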