예제 #1
0
    def _resolve_version(self, spec: AssetSpec) -> None:
        """Complete ``spec.version`` in place from the known asset versions.

        Nothing is done when ``spec`` already has a complete version.
        Otherwise the versions found locally (``_list_local_versions``) and,
        if a storage provider is configured, remotely
        (``storage_provider.get_versions_info``) are merged, sorted with
        ``spec.sort_versions`` and handed to ``spec.set_latest_version``.

        When no version is known anywhere and ``spec`` has no version, the
        spec is left untouched: it most likely designates a plain local
        file or directory path.

        Raises:
            errors.LocalAssetDoesNotExistError: a version was requested but
                no version of the asset is known locally or remotely.
        """
        if spec.is_version_complete():
            # Fully pinned version: no need to scan the local assets dir.
            return

        local_versions = self._list_local_versions(spec)
        logger.debug("Local versions", local_versions=local_versions)

        remote_versions = []
        if self.storage_provider:
            remote_versions = self.storage_provider.get_versions_info(
                spec.name)
            logger.debug("Fetched remote versions",
                         remote_versions=remote_versions)

        # A version may be present both locally and remotely: de-duplicate.
        all_versions = spec.sort_versions(
            version_list=set(local_versions + remote_versions))

        if not all_versions:
            if not spec.version:
                logger.debug("Asset has no version information")
                # no version is specified and none exist
                # in this case, the asset spec is likely a relative or absolute
                # path to a file/directory
                return None

            raise errors.LocalAssetDoesNotExistError(
                name=spec.name,
                version=spec.version,
                local_versions=local_versions,
            )

        # at least one version info is missing, update to the latest
        spec.set_latest_version(all_versions)
예제 #2
0
    def update(self, asset_path: str, name: str, version: str, dry_run=False):
        """
        Update an existing asset version

        Pushes ``asset_path`` as version ``version`` of asset ``name``, then
        rewrites the asset's versions object (a ``versions.json`` file with a
        ``"versions"`` list) so that it includes the new version.

        Args:
            asset_path: local path of the file or directory to push.
            name: name of the existing asset.
            version: version to push.
            dry_run: forwarded to ``push``; when true the versions object is
                not uploaded either.

        Raises:
            errors.AssetDoesNotExistError: the asset's versions object does
                not exist (the asset must be created first).
        """
        spec = AssetSpec(name=name, version=version)
        versions_object_name = self.get_versions_object_name(spec.name)
        if not self.driver.exists(versions_object_name):
            raise errors.AssetDoesNotExistError(spec.name)
        logger.info(
            "Updating asset",
            name=spec.name,
            version=spec.version,
            asset_path=asset_path,
        )
        versions_list = self.get_versions_info(spec.name)

        self.push(asset_path, spec.name, spec.version, dry_run=dry_run)

        with tempfile.TemporaryDirectory() as tmp_dir:
            versions_fn = os.path.join(tmp_dir, "versions.json")
            # De-duplicate: re-pushing an already listed version must not
            # produce repeated entries in the versions file.
            versions = spec.sort_versions({spec.version, *versions_list})
            with open(versions_fn, "w", encoding="utf-8") as f:
                json.dump({"versions": versions}, f)
            logger.debug(
                "Pushing updated versions file",
                name=spec.name,
                versions=versions,
            )
            if not dry_run:
                self.driver.upload_object(versions_fn, versions_object_name)
예제 #3
0
def test_create_asset(spec_dict, valid):
    """Valid specs build with and without an explicit system; invalid ones fail."""
    if not valid:
        with pytest.raises(errors.InvalidAssetSpecError):
            AssetSpec(**spec_dict)
        return

    # Without an explicit versioning argument the default (major_minor) applies.
    AssetSpec(**spec_dict)
    AssetSpec(**spec_dict, versioning="major_minor")
예제 #4
0
def test_fetch_asset_version_no_storage_provider(version_asset_name, version,
                                                 versioning):
    """Without a storage provider only assets already on disk are fetchable."""
    local_assets = os.path.join(TEST_DIR, "testdata", "test-bucket",
                                "assets-prefix")
    manager = AssetsManager(assets_dir=local_assets)
    asset_name = os.path.join("category", version_asset_name)
    spec = AssetSpec(name=asset_name, version=version, versioning=versioning)

    def fetch(force):
        return manager._fetch_asset_version(spec=spec, _force_download=force)

    # The local copy is served from the cache.
    result = fetch(False)
    assert result == {
        "from_cache": True,
        "version": version,
        "path": os.path.join(manager.assets_dir, asset_name, version),
    }

    # Forcing a download requires a storage driver, which is absent here.
    with pytest.raises(errors.StorageDriverError):
        fetch(True)

    # An asset that is not in the local directory cannot be resolved.
    spec.name = os.path.join("not-existing-asset", version_asset_name)
    with pytest.raises(errors.LocalAssetDoesNotExistError):
        fetch(False)
예제 #5
0
def test_asset_spec_get_local_versions():
    """Local major/minor versions are listed newest first."""
    spec = AssetSpec(name="name", versioning="major_minor")
    # A path that is not a directory yields no versions at all.
    assert spec.get_local_versions("not_a_dir") == []

    local_path = os.path.join(tests.TEST_DIR, "testdata", "test-bucket",
                              "assets-prefix", "category", "asset")
    expected = ["1.0", "0.1", "0.0"]
    assert spec.get_local_versions(local_path) == expected
예제 #6
0
def test_create_asset():
    """simple_date specs accept full timestamps ending in ``Z`` only."""
    valid_spec = AssetSpec(
        name="name",
        version="2020-11-15T17-30-56Z",
        versioning="simple_date",
    )
    assert isinstance(valid_spec.versioning, SimpleDateAssetsVersioningSystem)

    # Dropping the trailing "Z" makes the version invalid.
    with pytest.raises(errors.InvalidVersionError):
        AssetSpec(
            name="name",
            version="2020-11-15T17-30-56",
            versioning="simple_date",
        )
예제 #7
0
def test_asset_spec_is_version_complete():
    """A major_minor version is complete only when both parts are given."""
    cases = [
        ({"version": "1.1"}, True),
        ({"version": "1"}, False),
        ({}, False),
    ]
    for extra_kwargs, complete in cases:
        spec = AssetSpec(name="name", versioning="major_minor", **extra_kwargs)
        assert bool(spec.is_version_complete()) is complete
예제 #8
0
def test_asset_spec_sort_versions():
    """simple_date versions are sorted newest first."""
    newest, middle, oldest = (
        "2021-11-15T17-30-56Z",
        "2021-10-15T17-30-56Z",
        "2020-11-15T17-30-56Z",
    )
    spec = AssetSpec(name="name", versioning="simple_date")
    assert spec.sort_versions([newest, oldest, middle]) == [
        newest,
        middle,
        oldest,
    ]
예제 #9
0
def test_asset_spec_get_local_versions():
    """Local simple_date versions are listed newest first."""
    spec = AssetSpec(name="name", versioning="simple_date")
    # A path that is not a directory yields no versions at all.
    assert spec.get_local_versions("not_a_dir") == []

    local_path = os.path.join(tests.TEST_DIR, "testdata", "test-bucket",
                              "assets-prefix", "category",
                              "simple_date_asset")
    expected = ["2021-11-15T17-31-06Z", "2021-11-14T18-00-00Z"]
    assert spec.get_local_versions(local_path) == expected
예제 #10
0
    def fetch_asset(
        self,
        spec: Union[AssetSpec, str],
        return_info=False,
        force_download: bool = None,
    ):
        """Fetch an asset and return its local path (or its full info).

        Args:
            spec: an ``AssetSpec`` or a string parsed with
                ``AssetSpec.from_string``.
            return_info: when true, return the whole info dict produced by
                ``_fetch_asset`` instead of only its ``"path"`` entry.
            force_download: forwarded to ``_fetch_asset``. When ``None`` and a
                storage provider is configured, the provider's
                ``force_download`` setting is used instead.

        Raises:
            AssetFetchError: the fetched path does not exist on disk.
        """
        if isinstance(spec, str):
            spec = cast(AssetSpec, AssetSpec.from_string(spec))
        if force_download is None and self.storage_provider:
            force_download = self.storage_provider.force_download

        logger.info(
            "Fetching asset",
            spec=spec,
            return_info=return_info,
            force_download=force_download,
        )

        asset_info = self._fetch_asset(spec, _force_download=force_download)
        logger.debug("Fetched asset", spec=spec, asset_info=asset_info)
        path = asset_info["path"]
        if not os.path.exists(path):  # pragma: no cover
            logger.error(
                "An unknown error occurred when fetching asset. "
                "The path does not exist.",
                path=path,
                spec=spec,
            )
            raise AssetFetchError(
                f"An unknown error occurred when fetching asset {spec}. "
                f"The path {path} does not exist.")
        if not return_info:
            return path
        return asset_info
예제 #11
0
def test_asset_spec_set_latest_version():
    """With simple_date the newest available version always wins."""
    available = ["2021-11-15T17-31-06Z", "2021-11-14T18-00-00Z"]
    for extra_kwargs in ({}, {"version": "2021-11-14T18-00-00Z"}):
        spec = AssetSpec(name="a", versioning="simple_date", **extra_kwargs)
        spec.set_latest_version(list(available))
        assert spec.version == "2021-11-15T17-31-06Z"
예제 #12
0
def deploy_tf_models(lib, mode, config_name="config", verbose=False):
    """Write a TensorFlow Serving config for the TF models required by ``lib``.

    Args:
        lib: object exposing ``configuration`` and ``required_models``.
        mode: ``"local-docker"``, ``"local-process"`` or ``"remote"``; it
            selects how each model path is built. Any other value produces
            no model path, so nothing is written.
        config_name: basename (without ``.config``) of the file written to
            the assets directory.
        verbose: forwarded to ``write_config``.

    Raises:
        ValueError: ``remote`` mode without a storage provider, or a TF model
            without an asset.
    """
    manager = AssetsManager()
    configuration = lib.configuration
    model_paths = {}
    driver = None
    if mode == "remote":
        if not manager.storage_provider:
            raise ValueError(
                "A remote storage provider is required for `remote` mode")
        driver = manager.storage_provider.driver

    for model_name in lib.required_models:
        model_configuration = configuration[model_name]
        if not issubclass(model_configuration.model_type, TensorflowModel):
            logger.debug(f"Skipping non TF model `{model_name}`")
            continue
        if not model_configuration.asset:
            raise ValueError(
                f"TensorFlow model `{model_name}` does not have an asset")
        spec = AssetSpec.from_string(model_configuration.asset)
        # An unversioned spec gives an empty path component in every mode
        # (previously local-process produced a literal "None" directory).
        version = spec.version or ""
        if mode == "local-docker":
            model_paths[model_name] = "/".join(
                ("/config", spec.name, version)) + (spec.sub_part or "")
        elif mode == "local-process":
            model_paths[model_name] = os.path.join(
                manager.assets_dir,
                *spec.name.split("/"),
                version,
                *(spec.sub_part.split("/") if spec.sub_part else ()),
            )
        elif mode == "remote":
            object_name = manager.storage_provider.get_object_name(
                spec.name, version)
            model_paths[model_name] = driver.get_object_uri(object_name)

    if mode in ("local-docker", "local-process"):
        logger.info("Checking that local models are present.")
        download_assets(configuration=configuration,
                        required_models=lib.required_models)
    target = os.path.join(manager.assets_dir, f"{config_name}.config")

    if model_paths:
        logger.info(
            "Writing TF serving configuration locally.",
            config_name=config_name,
            target=target,
        )
        write_config(target, model_paths, verbose=verbose)
    else:
        logger.info(
            "Nothing to write",
            config_name=config_name,
            target=target,
        )
예제 #13
0
def new(asset_path, asset_spec, storage_prefix, dry_run):
    """
    Create a new asset.

    Create a new asset ASSET_SPEC with ASSET_PATH file.

    Will fail if asset exists (in this case use `update`).

    ASSET_PATH is the path to the file. The file can be local or on GCS
    (starting with gs://)

    ASSET_SPEC is an asset specification of the form
    [asset_name] (version information is not used: the initial version of
    the versioning system set in MODELKIT_ASSETS_VERSIONING_SYSTEM,
    major/minor by default, is pushed)

    NB: [asset_name] can contain `/` too.
    """
    _check_asset_file_number(asset_path)
    manager = StorageProvider(prefix=storage_prefix)
    print("Current assets manager:")
    print(f" - storage provider = `{manager.driver}`")
    print(f" - prefix = `{storage_prefix}`")

    print(f"Current asset: `{asset_spec}`")
    # Use the same versioning system as `update`, so that the initial version
    # pushed here can later be incremented by the same system.
    versioning_system = os.environ.get("MODELKIT_ASSETS_VERSIONING_SYSTEM",
                                       "major_minor")
    spec = AssetSpec.from_string(asset_spec, versioning=versioning_system)
    version = spec.versioning.get_initial_version()
    print(f" - versioning system = `{versioning_system}` ")
    print(f" - name = `{spec.name}`")

    print(f"Push a new asset `{spec.name}` with version `{version}`?")

    response = click.prompt("[y/N]")
    if response.strip().lower() == "y":
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Remote GCS objects are downloaded locally before being pushed.
            if asset_path.startswith("gs://"):
                asset_path = _download_object_or_prefix(
                    manager, asset_path=asset_path, destination_dir=tmp_dir)
            manager.new(asset_path, spec.name, version, dry_run)
    else:
        print("Aborting.")
예제 #14
0
def test_fetch_asset_version_with_sub_parts(version_asset_name, version,
                                            versioning, working_dir):
    """A spec's sub_part is appended to the resolved local asset path."""
    local_assets = os.path.join(TEST_DIR, "testdata", "test-bucket",
                                "assets-prefix")
    manager = AssetsManager(assets_dir=local_assets)
    asset_name = os.path.join("category", version_asset_name)
    sub_part = "sub_part"
    spec = AssetSpec(
        name=asset_name,
        version=version,
        sub_part=sub_part,
        versioning=versioning,
    )

    # No storage provider: the asset is served from the local directory.
    result = manager._fetch_asset_version(spec=spec, _force_download=False)

    expected_path = os.path.join(manager.assets_dir, asset_name, version,
                                 sub_part)
    assert result == {
        "from_cache": True,
        "version": version,
        "path": expected_path,
    }
예제 #15
0
 def _list_local_versions(self, spec: AssetSpec) -> List[str]:
     """Return the versions of ``spec`` present under ``self.assets_dir``.

     ``spec.name`` is split on ``/`` so that nested asset names map onto
     nested directories.
     """
     name_parts = spec.name.split("/")
     asset_dir = os.path.join(self.assets_dir, *name_parts)
     versions = spec.get_local_versions(asset_dir)
     return versions
예제 #16
0
def test_fetch_asset_version_with_storage_provider(version_asset_name, version,
                                                   versioning, working_dir):
    """Fetch through a local storage provider and exercise the cache."""
    provider = StorageProvider(
        provider="local",
        bucket=os.path.join(TEST_DIR, "testdata", "test-bucket"),
        prefix="assets-prefix",
    )
    manager = AssetsManager(assets_dir=working_dir, storage_provider=provider)

    asset_name = os.path.join("category", version_asset_name)
    spec = AssetSpec(name=asset_name, version=version, versioning=versioning)
    local_path = os.path.join(working_dir, asset_name, version)

    def check_fetch(force_download, expect_download):
        result = manager._fetch_asset_version(
            spec=spec,
            _force_download=force_download,
        )
        if expect_download:
            # Downloads also report metadata, which is not compared here.
            del result["meta"]
        assert result == {
            "from_cache": not expect_download,
            "version": version,
            "path": local_path,
        }

    # Nothing cached yet: the asset is downloaded.
    check_fetch(force_download=False, expect_download=True)
    # Second call is served from the local cache.
    check_fetch(force_download=False, expect_download=False)
    # Forcing bypasses the cache.
    check_fetch(force_download=True, expect_download=True)
    # A missing local copy triggers a new download.
    os.remove(local_path)
    check_fetch(force_download=False, expect_download=True)
예제 #17
0
def test_string_asset_spec(s, spec):
    """Parsing ``s`` under simple_date matches building the spec directly."""
    parsed = AssetSpec.from_string(s, versioning="simple_date")
    expected = AssetSpec(versioning="simple_date", **spec)
    assert parsed == expected
예제 #18
0
def test_versioning_values():
    """Known versioning systems are accepted; unknown ones are rejected."""
    AssetSpec(name="a")
    for system in ("major_minor", "simple_date"):
        AssetSpec(name="a", versioning=system)

    with pytest.raises(errors.UnknownAssetsVersioningSystemError):
        AssetSpec(name="a", versioning="unk_versioning")
예제 #19
0
def update(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
    """
    Update an existing asset using versioning system
    set in MODELKIT_ASSETS_VERSIONING_SYSTEM (major/minor by default)

    Update an existing asset ASSET_SPEC with ASSET_PATH file.

    By default will upload a new minor version.

    ASSET_PATH is the path to the file. The file can be local or remote
    (AWS or GCS) (starting with gs:// or s3://)

    ASSET_SPEC is an asset specification of the form
    [asset_name]:[version]

    Specific documentation depends on the chosen versioning system
    """

    _check_asset_file_number(asset_path)
    manager = StorageProvider(prefix=storage_prefix)

    print("Current assets manager:")
    print(f" - storage provider = `{manager.driver}`")
    print(f" - prefix = `{storage_prefix}`")

    print(f"Current asset: `{asset_spec}`")
    versioning_system = os.environ.get("MODELKIT_ASSETS_VERSIONING_SYSTEM",
                                       "major_minor")
    spec = AssetSpec.from_string(asset_spec, versioning=versioning_system)
    print(f" - versioning system = `{versioning_system}` ")
    print(f" - name = `{spec.name}`")
    print(f" - version = `{spec.version}`")

    try:
        version_list = manager.get_versions_info(spec.name)
    except ObjectDoesNotExistError:
        print("Remote asset not found. Create it first using `new`")
        sys.exit(1)

    # The versioning system decides what to display and how to bump.
    update_params = spec.versioning.get_update_cli_params(
        version=spec.version,
        version_list=version_list,
        bump_major=bump_major,
    )

    print(update_params["display"])
    new_version = spec.versioning.increment_version(
        spec.sort_versions(version_list),
        update_params["params"],
    )
    print(f"Push a new asset version `{new_version}` for `{spec.name}`?")

    response = click.prompt("[y/N]")
    if response.strip().lower() == "y":

        with tempfile.TemporaryDirectory() as tmp_dir:
            # NOTE(review): only gs:// paths are downloaded here although the
            # help text mentions s3:// — confirm that manager.update/push
            # handles s3:// paths directly.
            if asset_path.startswith("gs://"):
                asset_path = _download_object_or_prefix(
                    manager, asset_path=asset_path, destination_dir=tmp_dir)

            manager.update(
                asset_path,
                name=spec.name,
                version=new_version,
                dry_run=dry_run,
            )
    else:
        print("Aborting.")
예제 #20
0
def test_names(test, valid):
    """check_name_valid accepts valid names and rejects the others."""
    if not valid:
        with pytest.raises(errors.InvalidNameError):
            AssetSpec.check_name_valid(test)
        return
    AssetSpec.check_name_valid(test)
예제 #21
0
def test_asset_spec_set_latest_version():
    """The newest version within the requested major (if any) is selected."""
    cases = [
        ({}, ["3", "2.1", "1.3"], "3"),
        ({"version": "2"}, ["3", "2.1", "2.0", "1.3"], "2.1"),
        ({"version": "1.1"}, ["3", "2.1", "2.0", "1.3"], "1.3"),
    ]
    for extra_kwargs, available, expected in cases:
        spec = AssetSpec(name="a", versioning="major_minor", **extra_kwargs)
        spec.set_latest_version(available)
        assert spec.version == expected
예제 #22
0
    def _resolve_assets(self, model_name):
        """
        This function fetches assets for the current model and its dependent models
        and populates the assets_info dictionary with the paths.

        Assets of dependencies are resolved first (recursively). For the
        model's own asset, these sources are applied in order, each later one
        overwriting ``self.assets_info[configuration.asset]``:

        1. an ``asset_path`` entry in the merged model settings;
        2. the ``MODELKIT_<NAME>_FILE`` environment variable;
        3. ``self.override_assets_manager``, when set and the fetch does not
           raise ``AssetDoesNotExistError``.

        Only if none of these produced an entry is the asset fetched with
        ``self.assets_manager``, using the version from
        ``MODELKIT_<NAME>_VERSION`` when that variable is set.
        """
        logger.debug("Resolving asset for Model", model_name=model_name)
        configuration = self.configuration[model_name]
        # First, resolve assets from dependent models
        for dep_name in configuration.model_dependencies.values():
            self._resolve_assets(dep_name)

        if not configuration.asset:
            # If the model has no asset to load
            return

        model_settings = {
            **configuration.model_settings,
            **self.required_models.get(model_name, {}),
        }

        # If the asset is overriden in the model_settings
        if "asset_path" in model_settings:
            asset_path = model_settings["asset_path"]
            logger.debug(
                "Overriding asset from Model settings",
                model_name=model_name,
                asset_path=asset_path,
            )
            self.assets_info[configuration.asset] = AssetInfo(path=asset_path)

        asset_spec = AssetSpec.from_string(configuration.asset)
        # Prefix of the per-asset environment variables: runs of "/", "-"
        # and "." in the name become "_", e.g. "cat/my-asset" ->
        # "MODELKIT_CAT_MY_ASSET".
        env_prefix = "MODELKIT_{}".format(
            re.sub(r"[\/\-\.]+", "_", asset_spec.name).upper())

        # If the model's asset is overriden with environment variables
        local_file = os.environ.get(env_prefix + "_FILE")
        if local_file:
            logger.debug(
                "Overriding asset from environment variable",
                asset_name=asset_spec.name,
                path=local_file,
            )
            self.assets_info[configuration.asset] = AssetInfo(path=local_file)

        # The assets should be retrieved
        # possibly override version
        version = os.environ.get(env_prefix + "_VERSION")
        if version:
            logger.debug(
                "Overriding asset version from environment variable",
                asset_name=asset_spec.name,
                version=version,
            )
            # Only the version is overridden: keep the configured sub_part.
            asset_spec = AssetSpec(name=asset_spec.name,
                                   version=version,
                                   sub_part=asset_spec.sub_part)

        if self.override_assets_manager:
            # NOTE(review): this overwrites any asset_path / _FILE override
            # set above — confirm that this precedence is intended.
            try:
                self.assets_info[configuration.asset] = AssetInfo(
                    **self.override_assets_manager.fetch_asset(
                        spec=AssetSpec(name=asset_spec.name,
                                       sub_part=asset_spec.sub_part),
                        return_info=True,
                    ))
                logger.debug(
                    "Asset has been overriden",
                    name=asset_spec.name,
                )
            except modelkit.assets.errors.AssetDoesNotExistError:
                logger.debug(
                    "Asset not found in overriden prefix",
                    name=asset_spec.name,
                )

        if configuration.asset not in self.assets_info:
            self.assets_info[configuration.asset] = AssetInfo(
                **self.assets_manager.fetch_asset(asset_spec,
                                                  return_info=True))
예제 #23
0
def test_asset_spec_sort_versions(version_list, result):
    """major_minor sorting matches the parametrized expectation."""
    sorter = AssetSpec(name="name", versioning="major_minor")
    sorted_versions = sorter.sort_versions(version_list)
    assert sorted_versions == result
예제 #24
0
def test_string_asset_spec(s, spec):
    """from_string matches the constructor, with and without a system."""
    for extra in ({}, {"versioning": "major_minor"}):
        parsed = AssetSpec.from_string(s, **extra)
        assert parsed == AssetSpec(**extra, **spec)
예제 #25
0
def test_generic_check_version_valid(test, valid):
    """check_version_valid accepts valid versions and rejects the others."""
    if not valid:
        with pytest.raises(errors.InvalidVersionError):
            AssetSpec.check_version_valid(test)
        return
    AssetSpec.check_version_valid(test)