コード例 #1
0
ファイル: manager.py プロジェクト: clustree/modelkit
    def fetch_asset(
        self,
        spec: Union[AssetSpec, str],
        return_info=False,
        force_download: bool = None,
    ):
        """Fetch an asset and return its local path (or its full info).

        Args:
            spec: the asset to fetch, either an ``AssetSpec`` or a string
                parsed with ``AssetSpec.from_string``.
            return_info: when True, return the whole asset info mapping
                produced by ``_fetch_asset`` instead of only its ``"path"``.
            force_download: forwarded to ``_fetch_asset``. When left to None
                and a storage provider is configured, the provider's
                ``force_download`` setting is used instead.

        Returns:
            The local path of the asset, or the asset info mapping when
            ``return_info`` is True.

        Raises:
            AssetFetchError: if the path reported by ``_fetch_asset`` does
                not exist on the local filesystem.
        """
        if isinstance(spec, str):
            spec = cast(AssetSpec, AssetSpec.from_string(spec))
        if force_download is None and self.storage_provider:
            force_download = self.storage_provider.force_download

        logger.info(
            "Fetching asset",
            spec=spec,
            return_info=return_info,
            force_download=force_download,
        )

        asset_info = self._fetch_asset(spec, _force_download=force_download)
        logger.debug("Fetched asset", spec=spec, asset_info=asset_info)
        path = asset_info["path"]
        if not os.path.exists(path):  # pragma: no cover
            # Each message is built from two adjacent literals: the space
            # between the sentences must be explicit or they run together.
            logger.error(
                "An unknown error occurred when fetching asset. "
                "The path does not exist.",
                path=path,
                spec=spec,
            )
            raise AssetFetchError(
                f"An unknown error occurred when fetching asset {spec}. "
                f"The path {path} does not exist.")
        if not return_info:
            return path
        return asset_info
コード例 #2
0
def deploy_tf_models(lib, mode, config_name="config", verbose=False):
    """Write a TF serving configuration for the TensorFlow models of ``lib``.

    Only models whose ``model_type`` subclasses ``TensorflowModel`` are
    considered; other models are skipped.

    Args:
        lib: object exposing ``configuration`` and ``required_models``.
        mode: how model locations are written:
            - ``"local-docker"``: ``/config/<name>/<version>[/<sub part>]``,
            - ``"local-process"``: paths under the assets manager's
              ``assets_dir``,
            - ``"remote"``: object URIs from the storage provider's driver.
            Any other value yields no model path.
        config_name: the file ``<config_name>.config`` is written in the
            assets manager's ``assets_dir``.
        verbose: forwarded to ``write_config``.

    Raises:
        ValueError: in ``remote`` mode without a storage provider, or when a
            TensorFlow model has no asset.
    """
    manager = AssetsManager()
    configuration = lib.configuration
    model_paths = {}
    if mode == "remote":
        if not manager.storage_provider:
            raise ValueError(
                "A remote storage provider is required for `remote` mode")
        driver = manager.storage_provider.driver

    for model_name in lib.required_models:
        model_configuration = configuration[model_name]
        if not issubclass(model_configuration.model_type, TensorflowModel):
            logger.debug(f"Skipping non TF model `{model_name}`")
            continue
        if not model_configuration.asset:
            raise ValueError(
                f"TensorFlow model `{model_name}` does not have an asset")
        spec = AssetSpec.from_string(model_configuration.asset)
        if mode == "local-docker":
            # The sub part is its own path segment (as in the local-process
            # branch); concatenating it directly would glue it onto the
            # version when it does not start with "/".
            parts = ["/config", spec.name, spec.version or ""]
            if spec.sub_part:
                parts.append(spec.sub_part.lstrip("/"))
            model_paths[model_name] = "/".join(parts)
        elif mode == "local-process":
            # NOTE(review): f"{spec.version}" renders "None" for an
            # unversioned spec, unlike the other branches which use
            # `spec.version or ""` -- confirm versions are always pinned here.
            model_paths[model_name] = os.path.join(
                manager.assets_dir,
                *spec.name.split("/"),
                f"{spec.version}",
                *(spec.sub_part.split("/") if spec.sub_part else ()),
            )
        elif mode == "remote":
            object_name = manager.storage_provider.get_object_name(
                spec.name, spec.version or "")
            model_paths[model_name] = driver.get_object_uri(object_name)

    if mode in ("local-docker", "local-process"):
        logger.info("Checking that local models are present.")
        download_assets(configuration=configuration,
                        required_models=lib.required_models)
    target = os.path.join(manager.assets_dir, f"{config_name}.config")

    if model_paths:
        logger.info(
            "Writing TF serving configuration locally.",
            config_name=config_name,
            target=target,
        )
        write_config(target, model_paths, verbose=verbose)
    else:
        logger.info(
            "Nothing to write",
            config_name=config_name,
            target=target,
        )
コード例 #3
0
def new(asset_path, asset_spec, storage_prefix, dry_run):
    """
    Create a new asset.

    Create a new asset ASSET_SPEC with ASSET_PATH file.

    Will fail if asset exists (in this case use `update`).

    ASSET_PATH is the path to the file. The file can be local or on GCS
    (starting with gs://)

    ASSET_SPEC is an asset specification of the form
    [asset_name] (version information is ignored: the initial version of
    the versioning system set in MODELKIT_ASSETS_VERSIONING_SYSTEM,
    major/minor by default, is used)

    NB: [asset_name] can contain `/` too.
    """
    import os

    _check_asset_file_number(asset_path)
    manager = StorageProvider(prefix=storage_prefix)
    print("Current assets manager:")
    print(f" - storage provider = `{manager.driver}`")
    print(f" - prefix = `{storage_prefix}`")

    print(f"Current asset: `{asset_spec}`")
    # Use the same versioning system as `update`, so that the initial version
    # pushed here is one that later `update` calls can parse and bump.
    versioning_system = os.environ.get("MODELKIT_ASSETS_VERSIONING_SYSTEM",
                                       "major_minor")
    spec = AssetSpec.from_string(asset_spec, versioning=versioning_system)
    version = spec.versioning.get_initial_version()
    print(f" - name = `{spec.name}`")

    print(f"Push a new asset `{spec.name}` " f"with version `{version}`?")

    response = click.prompt("[y/N]")
    if response == "y":
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Remote sources are first downloaded into a temporary directory
            # that is removed once the push is done.
            if asset_path.startswith("gs://"):
                asset_path = _download_object_or_prefix(
                    manager, asset_path=asset_path, destination_dir=tmp_dir)
            manager.new(asset_path, spec.name, version, dry_run)
    else:
        print("Aborting.")
コード例 #4
0
def test_string_asset_spec(s, spec):
    # Parsing with the default versioning system must match direct building.
    parsed_default = AssetSpec.from_string(s)
    assert parsed_default == AssetSpec(**spec)

    # Same check with an explicit major/minor versioning system.
    parsed_major_minor = AssetSpec.from_string(s, versioning="major_minor")
    assert parsed_major_minor == AssetSpec(versioning="major_minor", **spec)
コード例 #5
0
def update(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
    """
    Update an existing asset using versioning system
    set in MODELKIT_ASSETS_VERSIONING_SYSTEM (major/minor by default)

    Update an existing asset ASSET_SPEC with ASSET_PATH file.

    By default will upload a new minor version.

    ASSET_PATH is the path to the file. The file can be local or remote
    (AWS or GCS) (starting with gs:// or s3://)

    ASSET_SPEC is an asset specification of the form
    [asset_name]:[version]

    Specific documentation depends on the chosen versioning system
    """

    _check_asset_file_number(asset_path)
    manager = StorageProvider(prefix=storage_prefix)

    print("Current assets manager:")
    print(f" - storage provider = `{manager.driver}`")
    print(f" - prefix = `{storage_prefix}`")

    print(f"Current asset: `{asset_spec}`")
    versioning_system = os.environ.get("MODELKIT_ASSETS_VERSIONING_SYSTEM",
                                       "major_minor")
    spec = AssetSpec.from_string(asset_spec, versioning=versioning_system)
    print(f" - versioning system = `{versioning_system}` ")
    print(f" - name = `{spec.name}`")
    print(f" - version = `{spec.version}`")

    # The asset must already exist remotely: its existing versions drive the
    # computation of the next one.
    try:
        version_list = manager.get_versions_info(spec.name)
    except ObjectDoesNotExistError:
        print("Remote asset not found. Create it first using `new`")
        sys.exit(1)

    update_params = spec.versioning.get_update_cli_params(
        version=spec.version,
        version_list=version_list,
        bump_major=bump_major,
    )

    print(update_params["display"])
    new_version = spec.versioning.increment_version(
        spec.sort_versions(version_list),
        update_params["params"],
    )
    print(f"Push a new asset version `{new_version}` " f"for `{spec.name}`?")

    response = click.prompt("[y/N]")
    if response == "y":

        with tempfile.TemporaryDirectory() as tmp_dir:
            # NOTE(review): only gs:// paths are downloaded here, although the
            # help text also mentions s3:// -- confirm whether s3:// paths
            # should go through _download_object_or_prefix as well.
            if asset_path.startswith("gs://"):
                asset_path = _download_object_or_prefix(
                    manager, asset_path=asset_path, destination_dir=tmp_dir)

            manager.update(
                asset_path,
                name=spec.name,
                version=new_version,
                dry_run=dry_run,
            )
    else:
        print("Aborting.")
コード例 #6
0
def test_string_asset_spec(s, spec):
    # With the "simple_date" versioning system, parsing the string must give
    # the same spec as building it directly with that system.
    parsed = AssetSpec.from_string(s, versioning="simple_date")
    assert parsed == AssetSpec(versioning="simple_date", **spec)
コード例 #7
0
    def _resolve_assets(self, model_name):
        """
        Fetch the assets of ``model_name`` and of its dependent models, and
        record their info in ``self.assets_info``, keyed by the asset string
        of the model configuration.

        Dependent models are resolved first, recursively. For the model's own
        asset, these sources are applied in order, a later one replacing the
        entry stored by an earlier one:

        1. ``asset_path`` in the model settings,
        2. the ``MODELKIT_<NAME>_FILE`` environment variable,
        3. the override assets manager, when configured and it has the asset.

        If none of them stored an entry, the asset is fetched with
        ``self.assets_manager``, at the version given by the
        ``MODELKIT_<NAME>_VERSION`` environment variable when it is set.
        """
        logger.debug("Resolving asset for Model", model_name=model_name)
        configuration = self.configuration[model_name]
        # First, resolve assets from dependent models
        for dep_name in configuration.model_dependencies.values():
            self._resolve_assets(dep_name)

        if not configuration.asset:
            # If the model has no asset to load
            return

        model_settings = {
            **configuration.model_settings,
            **self.required_models.get(model_name, {}),
        }

        # If the asset is overriden in the model_settings
        if "asset_path" in model_settings:
            asset_path = model_settings.pop("asset_path")
            logger.debug(
                "Overriding asset from Model settings",
                model_name=model_name,
                asset_path=asset_path,
            )
            self.assets_info[configuration.asset] = AssetInfo(path=asset_path)

        asset_spec = AssetSpec.from_string(configuration.asset)

        # Prefix of the environment variables overriding this asset: runs of
        # "/", "-" and "." in the name become "_", e.g. "a/b-c" -> MODELKIT_A_B_C
        env_prefix = "MODELKIT_{}".format(
            re.sub(r"[\/\-\.]+", "_", asset_spec.name).upper())

        # If the model's asset is overriden with environment variables
        local_file = os.environ.get(f"{env_prefix}_FILE")
        if local_file:
            logger.debug(
                "Overriding asset from environment variable",
                asset_name=asset_spec.name,
                path=local_file,
            )
            self.assets_info[configuration.asset] = AssetInfo(path=local_file)

        # The assets should be retrieved
        # possibly override version
        version = os.environ.get(f"{env_prefix}_VERSION")
        if version:
            logger.debug(
                "Overriding asset version from environment variable",
                asset_name=asset_spec.name,
                version=version,
            )
            # Keep the sub part: rebuilding the spec from "name:version"
            # alone would silently drop it.
            spec_string = f"{asset_spec.name}:{version}"
            if asset_spec.sub_part:
                spec_string += f"[{asset_spec.sub_part}]"
            asset_spec = AssetSpec.from_string(spec_string)

        # NOTE(review): this runs even when an entry was already stored above
        # (asset_path setting or MODELKIT_<NAME>_FILE) and replaces it on
        # success -- confirm that this precedence is intended.
        if self.override_assets_manager:
            try:
                self.assets_info[configuration.asset] = AssetInfo(
                    **self.override_assets_manager.fetch_asset(
                        spec=AssetSpec(name=asset_spec.name,
                                       sub_part=asset_spec.sub_part),
                        return_info=True,
                    ))
                logger.debug(
                    "Asset has been overriden",
                    name=asset_spec.name,
                )
            except modelkit.assets.errors.AssetDoesNotExistError:
                logger.debug(
                    "Asset not found in overriden prefix",
                    name=asset_spec.name,
                )

        if configuration.asset not in self.assets_info:
            self.assets_info[configuration.asset] = AssetInfo(
                **self.assets_manager.fetch_asset(asset_spec,
                                                  return_info=True))