Code example #1
File: mlflow_registry.py  Project: smellslikeml/rikai
from urllib.parse import urlparse

from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient

def _get_model_info(uri: str, client: MlflowClient) -> str:
    """Transform the rikai model URI into something that MLflow understands."""
    parsed = urlparse(uri)
    try:
        # The host component names a registered model: resolve it via the registry.
        client.get_registered_model(parsed.hostname)
        return _parse_model_ref(parsed, client)
    except MlflowException:
        # Fall back to interpreting the URI as a run-id reference.
        return _parse_runid_ref(parsed, client)
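The branch above hinges on `urlparse` putting the registered-model name into the authority part of the URI. A minimal sketch of that behaviour, using a made-up `mlflow://` URI purely for illustration:

from urllib.parse import urlparse

# Hypothetical rikai-style URI; the segment before the first "/" ends up in
# parsed.hostname (note that urlparse lower-cases the host component).
parsed = urlparse("mlflow://SocialTextAnalyzer/1")
print(parsed.hostname)  # socialtextanalyzer
print(parsed.path)      # /1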
Code example #2
import pytest

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import FEATURE_DISABLED, ErrorCode
from mlflow.tracking import MlflowClient
from mlflow.utils.file_utils import TempDir

def test_client_registry_operations_raise_exception_with_unsupported_registry_store():
    """
    This test case ensures that Model Registry operations invoked on the `MlflowClient`
    fail with an informative error message when the registry store URI refers to a
    store that does not support Model Registry features (e.g., FileStore).
    """
    with TempDir() as tmp:
        client = MlflowClient(registry_uri=tmp.path())
        expected_failure_functions = [
            client._get_registry_client,
            lambda: client.create_registered_model("test"),
            lambda: client.get_registered_model("test"),
            lambda: client.create_model_version("test", "source", "run_id"),
            lambda: client.get_model_version("test", 1),
        ]
        for func in expected_failure_functions:
            with pytest.raises(MlflowException) as exc:
                func()
            assert exc.value.error_code == ErrorCode.Name(FEATURE_DISABLED)
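By contrast, pointing `registry_uri` at a database-backed store enables the same operations. A minimal sketch (not part of the test above; the local SQLite file is an assumption):

from mlflow.tracking import MlflowClient

# A database-backed registry store supports Model Registry operations.
client = MlflowClient(registry_uri="sqlite:///mlruns.db")
client.create_registered_model("test")
print(client.get_registered_model("test").name)  # test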
Code example #3
# MAGIC %md
# MAGIC ## Load parameters from the MLflow Registry

# COMMAND ----------

import mlflow
from mlflow.tracking import MlflowClient

# set model & registry params
model_name = 'Airbnb_Model'
cmr_uri = 'databricks://mlops_webinar:CMR'
dev_uri = 'databricks://mlops_webinar:dev'

# fetch model from registry
client = MlflowClient(tracking_uri=dev_uri, registry_uri=cmr_uri)
reg_mdl = client.get_registered_model(model_name)

# COMMAND ----------

# find latest production model
prod_mdl = None
for mdl in reg_mdl.latest_versions:
    print(
        f"{mdl.name} (v. {mdl.version}) in {mdl.current_stage} is {mdl.status}"
    )
    if mdl.current_stage == "Production":
        prod_mdl = mdl

if prod_mdl is None:
    raise RuntimeError("No Production Model Found")
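A plausible next step once a Production version has been found (not part of the original notebook; the pyfunc flavor is an assumption) is to load it through a models:/ URI resolved against the central registry:

import mlflow

# Resolve models:/ URIs against the central model registry workspace.
mlflow.set_registry_uri(cmr_uri)

# Assumes the model was logged with a pyfunc-compatible flavor.
model = mlflow.pyfunc.load_model(f"models:/{model_name}/{prod_mdl.version}")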
Code example #4
import mlflow
from mlflow.tracking import MlflowClient

if __name__ == "__main__":

    def print_registered_model_info(rm):
        print("name: {}".format(rm.name))
        print("tags: {}".format(rm.tags))
        print("description: {}".format(rm.description))

    name = "SocialTextAnalyzer"
    tags = {"nlp.framework": "Spark NLP"}
    desc = "This sentiment analysis model classifies the tone: happy, sad, or angry."

    mlflow.set_tracking_uri("sqlite:///mlruns.db")
    client = MlflowClient()
    client.create_registered_model(name, tags, desc)
    print_registered_model_info(client.get_registered_model(name))
    print("--")

    # rename the model
    new_name = "SocialMediaTextAnalyzer"
    client.rename_registered_model(name, new_name)
    print_registered_model_info(client.get_registered_model(new_name))
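As a follow-up (not in the original example), the old name no longer resolves after the rename, and the renamed model can be deleted once it is no longer needed:

    # Fetching under the old name would now raise MlflowException.
    # client.get_registered_model(name)

    # Clean up the renamed model when done.
    client.delete_registered_model(new_name)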
Code example #5
import mlflow
from mlflow.tracking import MlflowClient

if __name__ == "__main__":

    def print_model_info(rm):
        print("--")
        print("name: {}".format(rm.name))
        print("tags: {}".format(rm.tags))

    name = "SocialMediaTextAnalyzer"
    tags = {"nlp.framework1": "Spark NLP"}
    mlflow.set_tracking_uri("sqlite:///mlruns.db")
    client = MlflowClient()

    # Create the registered model, set an additional tag, and fetch
    # the updated model info
    client.create_registered_model(name, tags)
    model = client.get_registered_model(name)
    print_model_info(model)
    client.set_registered_model_tag(name, "nlp.framework2", "VADER")
    model = client.get_registered_model(name)
    print_model_info(model)
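A complementary operation (a sketch, not part of the original snippet) removes one of the tags again and prints the result:

    # Remove the tag set at creation time and show the updated model info.
    client.delete_registered_model_tag(name, "nlp.framework1")
    print_model_info(client.get_registered_model(name))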