def test_get_databricks_runtime_no_spark_session():
    """get_databricks_runtime() yields None when no active Spark session exists,
    even while appearing to run inside a Databricks notebook."""
    no_session = mock.patch(
        "mlflow.utils.databricks_utils._get_active_spark_session", return_value=None
    )
    in_notebook = mock.patch(
        "mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True
    )
    with no_session, in_notebook:
        assert get_databricks_runtime() is None
def to_dict(self):
    """Serialize the model to a dictionary."""
    # Start from all public (non-underscore-prefixed) instance attributes.
    serialized = {
        key: value for key, value in self.__dict__.items() if not key.startswith("_")
    }
    # Record the Databricks runtime version only when one is reported.
    runtime = get_databricks_runtime()
    if runtime:
        serialized["databricks_runtime"] = runtime
    # Signature objects carry their own dict serialization.
    if self.signature is not None:
        serialized["signature"] = self.signature.to_dict()
    if self.saved_input_example_info is not None:
        serialized["saved_input_example_info"] = self.saved_input_example_info
    return serialized
def to_dict(self):
    """Serialize the model to a dictionary."""
    # Collect every public (non-underscore-prefixed) attribute as the base payload.
    payload = {
        name: attr for name, attr in self.__dict__.items() if not name.startswith("_")
    }
    runtime = get_databricks_runtime()
    if runtime:
        # Only present when running on a Databricks cluster that reports a version.
        payload["databricks_runtime"] = runtime
    if self.signature is not None:
        payload["signature"] = self.signature.to_dict()
    if self.saved_input_example_info is not None:
        payload["saved_input_example_info"] = self.saved_input_example_info
    # Drop the MLflow version key when the model has no recorded version;
    # pop with a default is a no-op if the key is absent.
    if self.mlflow_version is None:
        payload.pop(_MLFLOW_VERSION_KEY, None)
    return payload
def __init__(
    self,
    artifact_path=None,
    run_id=None,
    utc_time_created=None,
    flavors=None,
    signature=None,  # ModelSignature
    saved_input_example_info: Dict[str, Any] = None,
    **kwargs,
):
    """Initialize the model metadata container.

    Any extra keyword arguments are merged directly into the instance's
    attribute dict, so unknown fields round-trip through serialization.
    """
    # store model id instead of run_id and path to avoid confusion when model gets exported
    if run_id:
        self.run_id = run_id
        self.artifact_path = artifact_path
    runtime = get_databricks_runtime()
    if runtime:
        # Only set when a Databricks runtime version is actually reported.
        self.databricks_runtime = runtime
    # NOTE(review): utcnow() produces a naive UTC timestamp string; presumably
    # downstream consumers expect exactly this format — confirm before changing.
    self.utc_time_created = str(utc_time_created or datetime.utcnow())
    # Keep the caller's dict object when provided (even if empty).
    self.flavors = {} if flavors is None else flavors
    self.signature = signature
    self.saved_input_example_info = saved_input_example_info
    self.__dict__.update(kwargs)
def test_get_databricks_runtime_nondb(mock_spark_session):
    """Outside Databricks, get_databricks_runtime() returns None without
    ever consulting the Spark session configuration."""
    assert get_databricks_runtime() is None
    mock_spark_session.conf.get.assert_not_called()
def test_get_databricks_runtime_in_job(mock_spark_session):
    """Inside a Databricks job, the runtime version is read from the
    cluster usage tags in the Spark configuration."""
    in_job = mock.patch(
        "mlflow.utils.databricks_utils.is_in_databricks_job", return_value=True
    )
    with in_job:
        get_databricks_runtime()
        mock_spark_session.conf.get.assert_called_once_with(
            "spark.databricks.clusterUsageTags.sparkVersion", default=None
        )