예제 #1
0
def test_get_databricks_runtime_no_spark_session():
    """Check that get_databricks_runtime() returns None without a Spark session.

    The active-session lookup is patched to return None while the notebook
    check is patched to return True.
    """
    session_patch = mock.patch(
        "mlflow.utils.databricks_utils._get_active_spark_session",
        return_value=None,
    )
    notebook_patch = mock.patch(
        "mlflow.utils.databricks_utils.is_in_databricks_notebook",
        return_value=True,
    )
    # Enter the patches in the same order as they are declared above.
    with session_patch, notebook_patch:
        assert get_databricks_runtime() is None
예제 #2
0
 def to_dict(self):
     """Return a dictionary built from this object's public attributes.

     Attributes whose names start with an underscore are left out. The
     value of get_databricks_runtime() is added under
     "databricks_runtime" when it is truthy, the signature is replaced by
     its own ``to_dict()`` form when set, and the saved input example
     info is included when set.
     """
     serialized = {}
     for name, value in self.__dict__.items():
         if name.startswith("_"):
             continue
         serialized[name] = value

     runtime = get_databricks_runtime()
     if runtime:
         serialized["databricks_runtime"] = runtime

     signature = self.signature
     if signature is not None:
         serialized["signature"] = signature.to_dict()

     example_info = self.saved_input_example_info
     if example_info is not None:
         serialized["saved_input_example_info"] = example_info

     return serialized
예제 #3
0
 def to_dict(self):
     """Return a dictionary built from this object's public attributes.

     Underscore-prefixed attributes are skipped. A truthy result of
     get_databricks_runtime() is stored under "databricks_runtime"; a
     non-None signature is stored in its ``to_dict()`` form; non-None
     saved input example info is stored as is. When ``mlflow_version`` is
     None, the ``_MLFLOW_VERSION_KEY`` entry is dropped from the result.
     """
     public_attrs = (
         (name, value)
         for name, value in self.__dict__.items()
         if not name.startswith("_")
     )
     payload = dict(public_attrs)

     runtime = get_databricks_runtime()
     if runtime:
         payload["databricks_runtime"] = runtime

     if self.signature is not None:
         payload["signature"] = self.signature.to_dict()

     if self.saved_input_example_info is not None:
         payload["saved_input_example_info"] = self.saved_input_example_info

     if self.mlflow_version is None:
         # pop() with a default does nothing when the key is absent.
         payload.pop(_MLFLOW_VERSION_KEY, None)

     return payload
예제 #4
0
    def __init__(
            self,
            artifact_path=None,
            run_id=None,
            utc_time_created=None,
            flavors=None,
            signature=None,  # ModelSignature
            saved_input_example_info: Dict[str, Any] = None,
            **kwargs):
        """Initialize the instance's metadata attributes.

        :param artifact_path: Stored as ``self.artifact_path``, but only
                              when ``run_id`` is truthy.
        :param run_id: Stored as ``self.run_id`` when truthy.
        :param utc_time_created: Creation time; when falsy,
                                 ``datetime.utcnow()`` is used. The value
                                 is stored as a string.
        :param flavors: Stored as ``self.flavors``; ``None`` becomes a new
                        empty dict.
        :param signature: Stored as ``self.signature`` (a ModelSignature,
                          per the inline comment on the parameter).
        :param saved_input_example_info: Stored unchanged.
        :param kwargs: Extra attributes merged into the instance
                       ``__dict__`` after everything else.
        """
        # run_id and artifact_path are only recorded together, and only
        # when a run_id was supplied.
        if run_id:
            self.run_id, self.artifact_path = run_id, artifact_path

        runtime = get_databricks_runtime()
        if runtime:
            self.databricks_runtime = runtime

        created_at = utc_time_created if utc_time_created else datetime.utcnow()
        self.utc_time_created = str(created_at)

        if flavors is None:
            flavors = {}
        self.flavors = flavors

        self.signature = signature
        self.saved_input_example_info = saved_input_example_info

        # Applied last so extra keyword arguments can add to or override
        # any attribute set above.
        vars(self).update(kwargs)
예제 #5
0
def test_get_databricks_runtime_nondb(mock_spark_session):
    """Check that get_databricks_runtime() returns None and skips the Spark conf.

    No Databricks checks are patched here (presumably the non-Databricks
    case, going by the test name), so the mocked session's ``conf.get``
    must never be called.
    """
    assert get_databricks_runtime() is None
    conf_get = mock_spark_session.conf.get
    conf_get.assert_not_called()
예제 #6
0
def test_get_databricks_runtime_in_job(mock_spark_session):
    """Check that the Spark version conf key is read when inside a job.

    With ``is_in_databricks_job`` patched to return True, calling
    get_databricks_runtime() must query the mocked session's conf exactly
    once for the cluster usage tag holding the Spark version.
    """
    in_job_patch = mock.patch(
        "mlflow.utils.databricks_utils.is_in_databricks_job",
        return_value=True,
    )
    with in_job_patch:
        get_databricks_runtime()
        conf_get = mock_spark_session.conf.get
        conf_get.assert_called_once_with(
            "spark.databricks.clusterUsageTags.sparkVersion", default=None)