def model_packaging():
    """Return the additional items added to the model API in log_model().

    The mapping holds the value of ``_utils.get_python_version()`` together
    with a fixed model type (``"sklearn"``) and a fixed deserialization
    method (``"cloudpickle"``).
    """
    packaging = dict(
        python_version=_utils.get_python_version(),
        type="sklearn",
        deserialization="cloudpickle",
    )
    return packaging
def log_model(
    self,
    model,
    custom_modules=None,
    model_api=None,
    artifacts=None,
    overwrite=False,
):
    """Log `model`, plus its model API and custom modules, to this entity.

    Parameters
    ----------
    model
        Model to log; it is serialized with
        ``_artifact_utils.serialize_model``.
    custom_modules : optional
        Passed to ``self._custom_modules_as_artifact`` to build the
        custom-modules zip artifact.
    model_api : verta.utils.ModelAPI, optional
        Model API to log as JSON. When omitted, an empty ``utils.ModelAPI()``
        is created if the model is deployable. If it has no
        ``"model_packaging"`` entry, one is added to this object in place.
    artifacts : list of str, optional
        Keys of already-logged artifacts that the model depends on. Only
        allowed when `model` serializes as a class.
    overwrite : bool, default False
        Whether an existing model may be replaced. It is also forwarded to
        the attribute and artifact logging calls.

    Raises
    ------
    ValueError
        If a model already exists and `overwrite` is False, if `model_api`
        is not a ``verta.utils.ModelAPI``, if `artifacts` contains keys that
        have not been logged, or if `artifacts` is given for a model that
        does not serialize as a class.
    TypeError
        If `artifacts` is not a list of str.
    """
    if self.has_model and not overwrite:
        raise ValueError(
            "model already exists; consider setting overwrite=True")
    # Compare against None instead of relying on truthiness, so that falsy
    # non-ModelAPI values (e.g. an empty dict) are rejected too.
    if model_api is not None and not isinstance(model_api, utils.ModelAPI):
        raise ValueError(
            "`model_api` must be `verta.utils.ModelAPI`, not {}".format(
                type(model_api)))
    if artifacts is not None and not (
            isinstance(artifacts, list)
            and all(isinstance(artifact_key, six.string_types)
                    for artifact_key in artifacts)):
        raise TypeError("`artifacts` must be list of str, not {}".format(
            type(artifacts)))

    # validate that `artifacts` are actually logged
    if artifacts:
        self._refresh_cache()
        run_msg = self._msg
        existing_artifact_keys = {
            artifact.key for artifact in run_msg.artifacts
        }
        unlogged_artifact_keys = set(artifacts) - existing_artifact_keys
        if unlogged_artifact_keys:
            raise ValueError(
                "`artifacts` contains keys that have not been logged: {}".
                format(sorted(unlogged_artifact_keys)))

    # Serialize and validate the model *before* recording any metadata, so
    # that a rejected call does not leave a stale artifact-dependency
    # attribute behind.
    serialized_model, method, model_type = _artifact_utils.serialize_model(
        model)
    if artifacts and model_type != "class":
        raise ValueError(
            "`artifacts` can only be provided if `model` is a class")

    # associate artifact dependencies
    if artifacts:
        self.add_attribute(_MODEL_ARTIFACTS_ATTR_KEY, artifacts,
                           overwrite=overwrite)

    # Create artifact message and update ModelVersion's message:
    model_msg = self._create_artifact_msg(
        self._MODEL_KEY,
        serialized_model,
        artifact_type=_CommonCommonService.ArtifactTypeEnum.MODEL,
        method=method,
        framework=model_type,
    )
    model_version_update = self.ModelVersionMessage(model=model_msg)
    self._update(model_version_update)

    # Upload the artifact to ModelDB:
    self._upload_artifact(
        self._MODEL_KEY,
        serialized_model,
        _CommonCommonService.ArtifactTypeEnum.MODEL,
    )

    # create and upload model API
    # (only if provided or model is deployable)
    if model_type or model_api is not None:
        if model_api is None:
            model_api = utils.ModelAPI()
        if "model_packaging" not in model_api:
            # add model serialization info to model_api
            model_api["model_packaging"] = {
                "python_version": _utils.get_python_version(),
                "type": model_type,
                "deserialization": method,
            }
        self.log_artifact(_artifact_utils.MODEL_API_KEY, model_api,
                          overwrite, "json")

    # create and upload custom modules
    # (only if provided or model is deployable)
    if model_type or custom_modules:
        # Log modules:
        custom_modules_artifact = self._custom_modules_as_artifact(
            custom_modules)
        self.log_artifact(
            _artifact_utils.CUSTOM_MODULES_KEY,
            custom_modules_artifact,
            overwrite,
            "zip",
        )