Ejemplo n.º 1
0
def model_packaging():
    """Return the extra packaging entries that log_model() adds to the model API.

    The result describes the running Python version (as reported by
    ``_utils.get_python_version()``), the model type, and the method to use
    for deserialization.
    """
    packaging = dict(
        python_version=_utils.get_python_version(),
        type="sklearn",
        deserialization="cloudpickle",
    )
    return packaging
Ejemplo n.º 2
0
    def log_model(
        self,
        model,
        custom_modules=None,
        model_api=None,
        artifacts=None,
        overwrite=False,
    ):
        """Serialize `model` and log it along with its model API and custom modules.

        All argument validation, including the check that `artifacts` is only
        used with a class model, happens before anything is written, so a
        rejected call leaves no partially-updated state behind.

        Parameters
        ----------
        model
            Object to serialize via ``_artifact_utils.serialize_model()``.
        custom_modules : optional
            Forwarded to ``self._custom_modules_as_artifact()``. The result is
            logged as a zip artifact when serialization reports a model type
            or when this argument is provided.
        model_api : verta.utils.ModelAPI, optional
            Model API to log as a JSON artifact. When omitted and
            serialization reports a model type, a new ``ModelAPI`` is created.
            A ``"model_packaging"`` entry is added to it in place if missing.
        artifacts : list of str, optional
            Keys of artifacts that have already been logged and that the model
            depends on. Only permitted when `model` serializes as type
            ``"class"``.
        overwrite : bool, default False
            Whether to allow replacing an already-logged model. Also forwarded
            to the attribute and artifact logging calls.

        Raises
        ------
        ValueError
            If a model already exists and `overwrite` is False, if `model_api`
            is not a ``ModelAPI``, if `artifacts` names keys that have not
            been logged, or if `artifacts` is given for a non-class model.
        TypeError
            If `artifacts` is not a list of strings.

        """
        if self.has_model and not overwrite:
            raise ValueError(
                "model already exists; consider setting overwrite=True")

        if model_api and not isinstance(model_api, utils.ModelAPI):
            raise ValueError(
                "`model_api` must be `verta.utils.ModelAPI`, not {}".format(
                    type(model_api)))
        if artifacts is not None and not (isinstance(artifacts, list) and all(
                isinstance(artifact_key, six.string_types)
                for artifact_key in artifacts)):
            raise TypeError("`artifacts` must be list of str, not {}".format(
                type(artifacts)))

        # validate that `artifacts` are actually logged
        if artifacts:
            self._refresh_cache()
            run_msg = self._msg
            existing_artifact_keys = {
                artifact.key
                for artifact in run_msg.artifacts
            }
            unlogged_artifact_keys = set(artifacts) - existing_artifact_keys
            if unlogged_artifact_keys:
                raise ValueError(
                    "`artifacts` contains keys that have not been logged: {}".
                    format(sorted(unlogged_artifact_keys)))

        # Serialize up front: the model type is needed to finish validating
        # `artifacts`, and nothing should be persisted until validation passes.
        serialized_model, method, model_type = _artifact_utils.serialize_model(
            model)

        if artifacts and model_type != "class":
            raise ValueError(
                "`artifacts` can only be provided if `model` is a class")

        # associate artifact dependencies (only after all validation succeeded)
        if artifacts:
            self.add_attribute(_MODEL_ARTIFACTS_ATTR_KEY,
                               artifacts,
                               overwrite=overwrite)

        # Create artifact message and update ModelVersion's message:
        model_msg = self._create_artifact_msg(
            self._MODEL_KEY,
            serialized_model,
            artifact_type=_CommonCommonService.ArtifactTypeEnum.MODEL,
            method=method,
            framework=model_type,
        )
        model_version_update = self.ModelVersionMessage(model=model_msg)
        self._update(model_version_update)

        # Upload the artifact to ModelDB:
        self._upload_artifact(
            self._MODEL_KEY,
            serialized_model,
            _CommonCommonService.ArtifactTypeEnum.MODEL,
        )

        # create and upload model API
        if model_type or model_api:  # only if provided or model is deployable
            if model_api is None:
                model_api = utils.ModelAPI()
            if "model_packaging" not in model_api:
                # add model serialization info to model_api
                model_api["model_packaging"] = {
                    "python_version": _utils.get_python_version(),
                    "type": model_type,
                    "deserialization": method,
                }
            self.log_artifact(_artifact_utils.MODEL_API_KEY, model_api,
                              overwrite, "json")

        # create and upload custom modules
        if model_type or custom_modules:  # only if provided or model is deployable
            # Log modules:
            custom_modules_artifact = self._custom_modules_as_artifact(
                custom_modules)
            self.log_artifact(
                _artifact_utils.CUSTOM_MODULES_KEY,
                custom_modules_artifact,
                overwrite,
                "zip",
            )