예제 #1
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of an ``ImageDataSet`` contains.

        The version must appear only for the instance built with a
        ``Version``; the file path, class name and the word "protocol" must
        appear for both instances.
        """
        path = "/tmp/test.png"
        requested = Version(load_version, save_version)
        plain_text = str(ImageDataSet(filepath=path))
        versioned_text = str(ImageDataSet(filepath=path, version=requested))

        expected_version = (
            f"version=Version(load={load_version}, save='{save_version}')"
        )

        # Version information is exclusive to the versioned instance.
        assert "version" not in plain_text
        assert expected_version in versioned_text

        # Everything else is expected in both representations.
        for rendered in (plain_text, versioned_text):
            assert path in rendered
            assert "ImageDataSet" in rendered
            assert "protocol" in rendered
예제 #2
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of a ``ParquetDataSet`` contains.

        The version must appear only for the instance built with a
        ``Version``; ``FILENAME``, the class name and the word "protocol" must
        appear for both instances.
        """
        plain_text = str(ParquetDataSet(filepath=FILENAME))
        versioned_text = str(
            ParquetDataSet(
                filepath=FILENAME, version=Version(load_version, save_version)
            )
        )
        expected_version = "version=Version(load={}, save='{}')".format(
            load_version, save_version)

        # Version information is exclusive to the versioned instance.
        assert "version" not in plain_text
        assert expected_version in versioned_text

        # Everything else is expected in both representations.
        for rendered in (plain_text, versioned_text):
            assert FILENAME in rendered
            assert "ParquetDataSet" in rendered
            assert "protocol" in rendered
예제 #3
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of a ``TextDataSet`` contains.

        The version must appear only for the instance built with a
        ``Version``; the file path, class name and the word "protocol" must
        appear for both instances.
        """
        path = "test.txt"
        plain_text = str(TextDataSet(filepath=path))
        versioned_text = str(
            TextDataSet(filepath=path, version=Version(load_version, save_version))
        )
        expected_version = "version=Version(load={}, save='{}')".format(
            load_version, save_version)

        # Version information is exclusive to the versioned instance.
        assert "version" not in plain_text
        assert expected_version in versioned_text

        # Everything else is expected in both representations.
        for rendered in (plain_text, versioned_text):
            assert path in rendered
            assert "TextDataSet" in rendered
            assert "protocol" in rendered
예제 #4
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of an ``HDFLocalDataSet`` contains.

        The file path must appear for both instances, while the version must
        appear only for the instance built with a ``Version``.
        """
        path = "test.hdf"
        common_kwargs = {"filepath": path, "key": "test_hdf"}
        plain_text = str(HDFLocalDataSet(**common_kwargs))
        versioned_text = str(
            HDFLocalDataSet(
                version=Version(load_version, save_version), **common_kwargs
            )
        )
        expected_version = "version=Version(load={}, save='{}')".format(
            load_version, save_version)

        # Unversioned instance: path only.
        assert path in plain_text
        assert "version" not in plain_text

        # Versioned instance: path plus the exact version rendering.
        assert path in versioned_text
        assert expected_version in versioned_text
예제 #5
0
    def load(self, name: str, version: str = None) -> Any:
        """Load the data set registered under ``name``.

        Args:
            name: Name of the registered data set to load.
            version: Optional load version. When truthy it is wrapped in a
                ``Version`` (with no save version) and handed to the
                transformed load function.

        Returns:
            Whatever the transformed ``load`` function of the data set returns.

        Raises:
            DataSetNotFoundError: When ``name`` is not among the registered
                data sets.

        Example:
        ::

            >>> from kedro.io import DataCatalog
            >>> from kedro.extras.datasets.pandas import CSVDataSet
            >>>
            >>> cars = CSVDataSet(filepath="cars.csv",
            >>>                   load_args=None,
            >>>                   save_args={"index": False})
            >>> io = DataCatalog(data_sets={'cars': cars})
            >>>
            >>> df = io.load("cars")
        """
        if name not in self._data_sets:
            message = "DataSet '{}' not found in the catalog".format(name)
            raise DataSetNotFoundError(message)

        dataset_type = type(self._data_sets[name]).__name__
        self._logger.info("Loading data from `%s` (%s)...", name, dataset_type)

        if version:
            requested = Version(version, None)
        else:
            requested = None
        loader = self._get_transformed_dataset_function(name, "load", requested)
        data = loader()

        # Prefer the explicitly requested version; otherwise ask the data set
        # which version it loaded last.
        if requested:
            resolved = requested.load
        else:
            resolved = self._data_sets[name].get_last_load_version()

        # Journal only when a journal exists and a load version is known.
        if self._journal and resolved:
            self._journal.log_catalog(name, "load", resolved)
        return data
예제 #6
0
    def load(self, name: str, version: str = None) -> Any:
        """Load the data set registered under ``name``.

        Args:
            name: Name of the registered data set to load.
            version: Optional concrete data version to load. When truthy it is
                wrapped in a ``Version`` (with no save version) and passed to
                ``self._get_dataset``. Works only with versioned datasets.

        Returns:
            Whatever the transformed ``load`` function of the data set returns.

        Raises:
            DataSetNotFoundError: When a data set with the given name has not
                yet been registered (presumably raised by
                ``self._get_dataset``; that helper is not visible here).

        Example:
        ::

            >>> from kedro.io import DataCatalog
            >>> from kedro.extras.datasets.pandas import CSVDataSet
            >>>
            >>> cars = CSVDataSet(filepath="cars.csv",
            >>>                   load_args=None,
            >>>                   save_args={"index": False})
            >>> io = DataCatalog(data_sets={'cars': cars})
            >>>
            >>> df = io.load("cars")
        """
        if version:
            requested = Version(version, None)
        else:
            requested = None
        dataset = self._get_dataset(name, version=requested)

        dataset_type = type(dataset).__name__
        self._logger.info("Loading data from `%s` (%s)...", name, dataset_type)

        loader = self._get_transformed_dataset_function(name, "load", dataset)
        data = loader()

        # Only versioned data sets can report which version was loaded.
        resolved = None
        if isinstance(dataset, AbstractVersionedDataSet):
            resolved = dataset.resolve_load_version()

        # Journal only when a journal exists and a load version is known.
        if self._journal and resolved:
            self._journal.log_catalog(name, "load", resolved)
        return data
예제 #7
0
    def __init__(
        self,
        filepath: str,
        save_args: Dict[str, Any] = None,
        version: Version = Version(None, None),
        credentials: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ) -> None:
        """Create a ``JSONDataSet`` pointing to a concrete JSON file on a
        specific filesystem.

        Every argument is forwarded unchanged to the parent initializer.

        Args:
            filepath: Filepath in POSIX format to a JSON file prefixed with a
                protocol like `s3://`. If prefix is not provided, `file`
                protocol (local filesystem) will be used. The prefix should be
                any protocol supported by ``fsspec``.
                Note: `http(s)` doesn't support versioning.
            save_args: Options for saving JSON files (arguments passed into
                ``json.dump``). All available arguments are listed at
                https://docs.python.org/3/library/json.html
                NOTE(review): the previous text said "default_flow_style" is
                set to False, but that is a YAML option; confirm the real
                defaults against the parent class.
            version: An instance of ``kedro.io.core.Version``; defaults to
                ``Version(None, None)``. If its ``load`` attribute is None,
                the latest version will be loaded. If its ``save`` attribute
                is None, save version will be autogenerated. Versioning for
                this dataset is turned on by default and can not be turned
                off.
            credentials: Credentials required to get access to the underlying
                filesystem. E.g. for ``GCSFileSystem`` it should look like
                `{"token": None}`.
            fs_args: Extra arguments to pass into underlying filesystem class
                constructor (e.g. `{"project": "my-project"}` for
                ``GCSFileSystem``), as well as to pass to the filesystem's
                `open` method through nested keys `open_args_load` and
                `open_args_save`. All available arguments for `open`:
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
                All defaults are preserved, except `mode`, which is set to `r`
                when loading and to `w` when saving.
        """
        forwarded = {
            "filepath": filepath,
            "save_args": save_args,
            "credentials": credentials,
            "version": version,
            "fs_args": fs_args,
        }
        super().__init__(**forwarded)
예제 #8
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of a ``JSONDataSet`` contains.

        The version must appear only for the instance built with a
        ``Version``; the file path, class name, the word "protocol" and the
        default ``save_args`` must appear for both instances.
        """
        path = "test.json"
        plain_text = str(JSONDataSet(filepath=path))
        versioned_text = str(
            JSONDataSet(filepath=path, version=Version(load_version, save_version))
        )
        expected_version = (
            f"version=Version(load={load_version}, save='{save_version}')"
        )

        # Version information is exclusive to the versioned instance.
        assert "version" not in plain_text
        assert expected_version in versioned_text

        # Everything else is expected in both representations.
        for rendered in (plain_text, versioned_text):
            assert path in rendered
            assert "JSONDataSet" in rendered
            assert "protocol" in rendered
            # Default save_args
            assert "save_args={'indent': 2}" in rendered
예제 #9
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of a ``YAMLDataSet`` contains.

        The version must appear only for the instance built with a
        ``Version``; the file path, class name, the word "protocol" and the
        default ``save_args`` must appear for both instances.
        """
        path = "test.yaml"
        plain_text = str(YAMLDataSet(filepath=path))
        versioned_text = str(
            YAMLDataSet(filepath=path, version=Version(load_version, save_version))
        )
        expected_version = (
            f"version=Version(load={load_version}, save='{save_version}')"
        )

        # Version information is exclusive to the versioned instance.
        assert "version" not in plain_text
        assert expected_version in versioned_text

        # Everything else is expected in both representations.
        for rendered in (plain_text, versioned_text):
            assert path in rendered
            assert "YAMLDataSet" in rendered
            assert "protocol" in rendered
            # Default save_args
            assert "save_args={'default_flow_style': False}" in rendered
예제 #10
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of an ``HDFS3DataSet`` contains.

        ``FILENAME`` must appear for both instances, while the version must
        appear only for the instance built with a ``Version``.
        """
        plain_text = str(
            HDFS3DataSet(filepath=FILENAME, bucket_name=BUCKET_NAME, key="test_hdf")
        )
        versioned_text = str(
            HDFS3DataSet(
                filepath=FILENAME,
                bucket_name=BUCKET_NAME,
                credentials=AWS_CREDENTIALS,
                key="test_hdf",
                version=Version(load_version, save_version),
            )
        )
        expected_version = "version=Version(load={}, save='{}')".format(
            load_version, save_version
        )

        # Unversioned instance: file name only.
        assert FILENAME in plain_text
        assert "version" not in plain_text

        # Versioned instance: file name plus the exact version rendering.
        assert FILENAME in versioned_text
        assert expected_version in versioned_text
예제 #11
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of an ``HDFDataSet`` contains.

        The version must appear only for the instance built with a
        ``Version``; the file path, class name and the words "protocol" and
        "key" must appear for both instances.
        """
        path = "test.h5"
        plain_text = str(HDFDataSet(filepath=path, key=HDF_KEY))
        versioned_text = str(
            HDFDataSet(
                filepath=path,
                key=HDF_KEY,
                version=Version(load_version, save_version),
            )
        )
        expected_version = (
            f"version=Version(load={load_version}, save='{save_version}')"
        )

        # Version information is exclusive to the versioned instance.
        assert "version" not in plain_text
        assert expected_version in versioned_text

        # Everything else is expected in both representations.
        for rendered in (plain_text, versioned_text):
            assert path in rendered
            assert "HDFDataSet" in rendered
            assert "protocol" in rendered
            assert "key" in rendered
예제 #12
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of a ``CSVDataSet`` contains.

        The version must appear only for the instance built with a
        ``Version``; the file path, class name, the word "protocol" and the
        default ``save_args`` must appear for both instances.
        """
        path = "test.csv"
        plain_text = str(CSVDataSet(filepath=path))
        versioned_text = str(
            CSVDataSet(filepath=path, version=Version(load_version, save_version))
        )
        expected_version = "version=Version(load={}, save='{}')".format(
            load_version, save_version)

        # Version information is exclusive to the versioned instance.
        assert "version" not in plain_text
        assert expected_version in versioned_text

        # Everything else is expected in both representations.
        for rendered in (plain_text, versioned_text):
            assert path in rendered
            assert "CSVDataSet" in rendered
            assert "protocol" in rendered
            # Default save_args
            assert "save_args={'index': False}" in rendered
예제 #13
0
    def test_version_str_repr(self, load_version, save_version):
        """Check what ``str()`` of an ``EmailMessageDataSet`` contains.

        The version must appear only for the instance built with a
        ``Version``; the file path, class name, the word "protocol" and the
        default ``parser_args`` must appear for both instances.
        """
        path = "test"
        plain_text = str(EmailMessageDataSet(filepath=path))
        versioned_text = str(
            EmailMessageDataSet(
                filepath=path, version=Version(load_version, save_version)
            )
        )
        expected_version = (
            f"version=Version(load={load_version}, save='{save_version}')"
        )
        # Default parser_args
        expected_parser_args = f"parser_args={{'policy': {default}}}"

        # Version information is exclusive to the versioned instance.
        assert "version" not in plain_text
        assert expected_version in versioned_text

        # Everything else is expected in both representations.
        for rendered in (plain_text, versioned_text):
            assert path in rendered
            assert "EmailMessageDataSet" in rendered
            assert "protocol" in rendered
            assert expected_parser_args in rendered
예제 #14
0
    def test_save(self, filepath_json, dummy_data, tmp_path, save_version):
        """Save through a versioned ``JSONDataSet`` and verify the file on disk.

        The file written under ``<filepath>/<save_version>/test.json`` is read
        back directly and compared with a plain ``json`` round-trip of
        ``dummy_data``. The data set's fsspec open arguments are checked too.
        """
        data_set = JSONDataSet(
            filepath=filepath_json, version=Version(None, save_version)
        )
        data_set.save(dummy_data)

        # Reference copy: dump and reload the same data with the json module.
        reference_path = tmp_path / "locally_saved.json"
        reference_path.parent.mkdir(parents=True, exist_ok=True)
        with open(reference_path, "w", encoding="utf-8") as handle:
            json.dump(dummy_data, handle)
        with open(reference_path, encoding="utf-8") as handle:
            expected = json.load(handle)

        # File produced by the data set inside its versioned directory.
        saved_path = (
            Path(data_set._filepath.as_posix()) / save_version / "test.json"
        )
        with open(saved_path, encoding="utf-8") as handle:
            actual = json.load(handle)

        assert actual == expected
        assert data_set._fs_open_args_load == {}
        assert data_set._fs_open_args_save == {"mode": "w"}
def test_save():
    """Train a model from a CSV stored on GCS and save it through
    ``GCSTensorflowModelDataSet``.

    The test loads the CSV data set, splits it, trains and evaluates a model,
    and finally saves the model. Any exception raised by the save is reported
    as an ``AssertionError`` whose message starts with "save failed".
    """
    # NOTE(review): the credential paths below are machine-specific absolute
    # paths; the test can only run where that file exists.
    tensorflow_dataset = GCSTensorflowModelDataSet(
        "gcs://coachneuro-dev-ml/models/basketball/front_legs.h5",
        version=Version(None, None),
        credentials={
            "id_token":
            "C:/Users/Ethan/CoachNeuro/coach-neuro-ml/conf/local"
            "/coachneuro-dev-ml.json"
        },
        fs_args={"project": "coachneuro-dev"})

    dataset = GCSCSVDataSet(
        "gcs://coachneuro-dev-ml/primary-csv-data/basketball/front_legs.csv", {
            "id_token":
            "C:/Users/Ethan/CoachNeuro/coach-neuro-ml/conf/local/coachneuro-dev-ml.json"
        }, {"project": "coachneuro-dev"})

    df = dataset.load()

    # NOTE(review): the unpack order (X_val, y_val before y_train, y_test) is
    # unusual; confirm it matches what split_data_generic returns.
    X_train, X_test, X_val, y_val, y_train, y_test = split_data_generic(
        df, {
            "test_size": 0.15,
            "val_size": 0.2,
            "random_state": 69
        })
    model = train_model_generic(X_train, y_train, X_val, y_val, {
        "input_dim": 66,
        "output_dim": 3
    })

    evaluate_model_generic(model, X_test, y_test, "BasketballFrontLegsModel")

    try:
        tensorflow_dataset.save(model)
    except Exception as e:
        # Raise explicitly instead of ``assert False``: assert statements are
        # removed under ``python -O``, which would let a failed save pass
        # silently. Chaining keeps the original traceback.
        raise AssertionError(f"save failed {e}") from e
예제 #16
0
def versioned_tf_model_dataset(filepath, load_version, save_version):
    """Return a ``TensorFlowModelDataset`` for ``filepath`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return TensorFlowModelDataset(filepath=filepath, version=version)
예제 #17
0
    def test_http_filesystem_no_versioning(self):
        """Expect ``DataSetError`` when a ``TensorFlowModelDataset`` is
        created with a version on an ``https://`` file path."""
        expected_message = r"HTTP\(s\) DataSet doesn't support versioning\."
        any_version = Version(None, None)

        with pytest.raises(DataSetError, match=expected_message):
            TensorFlowModelDataset(
                filepath="https://example.com/file.tf", version=any_version
            )
예제 #18
0
def versioned_pickle_data_set(filepath_pkl, load_version, save_version):
    """Return a ``PickleLocalDataSet`` for ``filepath_pkl`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return PickleLocalDataSet(filepath=filepath_pkl, version=version)
예제 #19
0
def explicit_versioned_json_dataset(filepath_json, load_version, save_version):
    """Return a ``JSONDataSet`` for ``filepath_json`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return JSONDataSet(filepath=filepath_json, version=version)
예제 #20
0
def explicit_versioned_metrics_dataset(filepath_json, load_version,
                                       save_version):
    """Return a ``MetricsDataSet`` for ``filepath_json`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return MetricsDataSet(filepath=filepath_json, version=version)
예제 #21
0
def versioned_txt_data_set(filepath_txt, load_version, save_version):
    """Return a ``TextDataSet`` for ``filepath_txt`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return TextDataSet(filepath=filepath_txt, version=version)
예제 #22
0
    def test_http_filesystem_no_versioning(self):
        """Expect ``DataSetError`` when a ``GeoJSONDataSet`` is created with a
        version on an ``https://`` file path."""
        expected_message = r"HTTP\(s\) DataSet doesn't support versioning\."
        any_version = Version(None, None)

        with pytest.raises(DataSetError, match=expected_message):
            GeoJSONDataSet(
                filepath="https://example/file.geojson", version=any_version
            )
예제 #23
0
def versioned_image_dataset(filepath_png, load_version, save_version):
    """Return an ``ImageDataSet`` for ``filepath_png`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return ImageDataSet(filepath=filepath_png, version=version)
예제 #24
0
def versioned_parquet_data_set(data_path, load_version, save_version):
    """Return a ``ParquetLocalDataSet`` for ``data_path`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return ParquetLocalDataSet(filepath=data_path, version=version)
예제 #25
0
def versioned_geojson_data_set(filepath, load_version, save_version):
    """Return a ``GeoJSONDataSet`` for ``filepath`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return GeoJSONDataSet(filepath=filepath, version=version)
예제 #26
0
def versioned_message_data_set(filepath_message, load_version, save_version):
    """Return an ``EmailMessageDataSet`` for ``filepath_message`` configured
    with ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return EmailMessageDataSet(filepath=filepath_message, version=version)
예제 #27
0
def versioned_feather_data_set(filepath_feather, load_version, save_version):
    """Return a ``FeatherDataSet`` for ``filepath_feather`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return FeatherDataSet(filepath=filepath_feather, version=version)
예제 #28
0
def versioned_xls_data_set(filepath_xls, load_version, save_version):
    """Return an ``ExcelLocalDataSet`` for ``filepath_xls`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return ExcelLocalDataSet(filepath=filepath_xls, version=version)
예제 #29
0
def versioned_yaml_data_set(filepath_yaml, load_version, save_version):
    """Return a ``YAMLDataSet`` for ``filepath_yaml`` configured with
    ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    return YAMLDataSet(filepath=filepath_yaml, version=version)
예제 #30
0
def versioned_csv_data_set(filepath, load_version, save_version):
    """Return a ``CSVLocalDataSet`` for ``filepath`` that saves with a comma
    separator and is configured with ``Version(load_version, save_version)``."""
    version = Version(load_version, save_version)
    comma_separated = {"sep": ","}
    return CSVLocalDataSet(
        filepath=filepath, save_args=comma_separated, version=version
    )