def get_underlying_uri(uri):
    """Resolve a ``models:/`` URI to the underlying artifact download URI.

    Parses ``uri`` into (name, version, stage); when a stage is given, the
    latest version in that stage is looked up first.

    Args:
        uri: A model-registry URI understood by
            ``ModelsArtifactRepository._parse_uri``.

    Returns:
        The artifact download URI for the resolved model version.

    Raises:
        MlflowException: If a stage was requested but the registered model
            has no versions in that stage.
    """
    # Note: to support a registry URI that is different from the tracking URI
    # here, we'll need to add setting of registry URIs via environment
    # variables.
    from mlflow.tracking import MlflowClient

    client = MlflowClient()
    (name, version, stage) = ModelsArtifactRepository._parse_uri(uri)
    if stage is not None:
        latest = client.get_latest_versions(name, [stage])
        # Guard against an empty result: indexing latest[0] unchecked would
        # raise an opaque IndexError instead of a clear registry error.
        if len(latest) == 0:
            raise MlflowException("No versions of model with name '{name}' and "
                                  "stage '{stage}' found".format(name=name, stage=stage))
        version = latest[0].version
    return client.get_model_version_download_uri(name, version)
def get_underlying_uri(uri):
    """Resolve a registry model URI to its artifact download URI.

    Determines the registry to talk to from a Databricks profile embedded in
    the artifact URI, falling back to the globally configured registry URI,
    then propagates that profile info onto the returned download URI.
    """
    # Note: to support a registry URI that is different from the tracking URI
    # here, we'll need to add setting of registry URIs via environment
    # variables.
    from mlflow.tracking import MlflowClient

    profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
    if not profile_uri:
        profile_uri = mlflow.get_registry_uri()
    registry_client = MlflowClient(registry_uri=profile_uri)
    name, version = get_model_name_and_version(registry_client, uri)
    resolved = registry_client.get_model_version_download_uri(name, version)
    return add_databricks_profile_info_to_artifact_uri(resolved, profile_uri)
def get_underlying_uri(uri):
    """Resolve a ``models:/`` URI to the underlying artifact download URI.

    When the URI names a stage rather than a concrete version, the latest
    version in that stage is used; an ``MlflowException`` is raised if the
    stage holds no versions.
    """
    # Note: to support a registry URI that is different from the tracking URI
    # here, we'll need to add setting of registry URIs via environment
    # variables.
    from mlflow.tracking import MlflowClient

    client = MlflowClient()
    name, version, stage = ModelsArtifactRepository._parse_uri(uri)
    # Concrete version requested: no stage lookup needed.
    if stage is None:
        return client.get_model_version_download_uri(name, version)
    staged = client.get_latest_versions(name, [stage])
    if not staged:
        raise MlflowException("No versions of model with name '{name}' and "
                              "stage '{stage}' found".format(name=name, stage=stage))
    return client.get_model_version_download_uri(name, staged[0].version)
import mlflow.sklearn
from mlflow.tracking import MlflowClient
from sklearn.ensemble import RandomForestRegressor

if __name__ == "__main__":
    # Use a local SQLite-backed store so the model-registry APIs are available.
    mlflow.set_tracking_uri("sqlite:///mlruns.db")

    hyperparams = {"n_estimators": 3, "random_state": 42}
    model_name = "RandomForestRegression"
    regressor = RandomForestRegressor(**hyperparams).fit([[0, 1]], [1])

    # Log MLflow entities
    with mlflow.start_run() as active_run:
        mlflow.log_params(hyperparams)
        mlflow.sklearn.log_model(regressor, artifact_path="models/sklearn-model")

    # Register model name in the model registry
    registry_client = MlflowClient()
    registry_client.create_registered_model(model_name)

    # Create a new version of the rfr model under the registered model name
    source_uri = "runs:/{}/models/sklearn-model".format(active_run.info.run_id)
    new_version = registry_client.create_model_version(model_name, source_uri, active_run.info.run_id)

    download_uri = registry_client.get_model_version_download_uri(model_name, new_version.version)
    print("Download URI: {}".format(download_uri))
def main(
    mlflow_server: str,
    significance: float,
) -> None:
    """Train two classifiers, register them in MLflow, and promote the winner.

    A Logistic Regression model is trained and moved straight to Production;
    a Random Forest is trained and placed in Staging. Both are then reloaded
    from the registry and compared on the held-out test set with McNemar's
    test: if they differ significantly and the Random Forest is more accurate,
    it replaces the Logistic Regression in Production.

    Args:
        mlflow_server: Tracking URI of a reachable MLflow server.
        significance: Significance level for the McNemar test (e.g. 0.05).
    """
    # We start by setting the tracking uri to make sure the mlflow server is reachable
    mlflow.set_tracking_uri(mlflow_server)
    # We need to instantiate the MlflowClient class for certain operations
    # (stage transitions and download-URI lookups below).
    mlflow_client = MlflowClient()
    # We create and set an experiment to group all runs
    mlflow.set_experiment("Model Comparison")
    # We create classification data and split it into training and testing sets.
    # NOTE(review): `random_seed` is not defined in this function — presumably a
    # module-level constant; confirm it exists at import time.
    X, y = make_classification(
        n_samples=10000,
        n_classes=2,
        n_features=20,
        n_informative=9,
        random_state=random_seed,
    )
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, test_size=0.2)
    # We first train a Logistic regression model, log it in mlflow and then
    # move it to the production stage. Version 1 is hard-coded here and below,
    # so this assumes a fresh registry with no prior versions of these names.
    with mlflow.start_run():
        lr_model = LogisticRegression()
        lr_model.fit(X_train, y_train)
        y_pred = lr_model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        mlflow.log_metric("accuracy", accuracy)
        mlflow.sklearn.log_model(lr_model, artifact_path="model", registered_model_name="Logistic Regression")
        mlflow_client.transition_model_version_stage(name="Logistic Regression", version=1, stage="Production")
    # We then train a Random Forest model, log it in mlflow and then move it to the staging stage
    with mlflow.start_run():
        rf_model = RandomForestClassifier()
        rf_model.fit(X_train, y_train)
        y_pred = rf_model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        mlflow.log_metric("accuracy", accuracy)
        mlflow.sklearn.log_model(rf_model, artifact_path="model", registered_model_name="Random Forest")
        mlflow_client.transition_model_version_stage(name="Random Forest", version=1, stage="Staging")
    # Drop the in-memory models so the comparison below provably uses the
    # registry copies rather than the objects just trained.
    del lr_model
    del rf_model
    # We finally load both models from MLFlow
    # and compare them using the McNemar test
    # We get the download uris of both models and then we load them
    lr_model_download_uri = mlflow_client.get_model_version_download_uri(
        name="Logistic Regression",
        version=1,
    )
    rf_model_download_uri = mlflow_client.get_model_version_download_uri(
        name="Random Forest",
        version=1,
    )
    lr_model = mlflow.sklearn.load_model(lr_model_download_uri)
    rf_model = mlflow.sklearn.load_model(rf_model_download_uri)
    y_pred_lr = lr_model.predict(X_test)
    y_pred_rf = rf_model.predict(X_test)
    # Build the 2x2 right/wrong contingency table for the two models on the
    # same test labels. NOTE(review): `mcnemar_table` and `mcnemar` look like
    # mlxtend's implementations (mcnemar returns (chi2, p)) — confirm imports.
    contingency_table = mcnemar_table(y_test, y_pred_lr, y_pred_rf)
    _, p_value = mcnemar(contingency_table, corrected=True)
    if p_value < significance:
        # In this case we reject the null hypothesis that the two models are similar.
        # We then archive the logistic regression model
        # and move the random forest model to the Production stage
        print(
            f"p-value {p_value} smaller than significance level {significance}"
        )
        accuracy_lr = accuracy_score(y_test, y_pred_lr)
        accuracy_rf = accuracy_score(y_test, y_pred_rf)
        if accuracy_lr < accuracy_rf:
            print(
                f"Random Forest model's accuracy, {accuracy_rf}, is greater than "
                f"the Logistic Regression model's accuracy, {accuracy_lr}")
            print(
                "Archiving logistic regression model and moving random forest model to production"
            )
            mlflow_client.transition_model_version_stage(
                name="Logistic Regression",
                version=1,
                stage="Archived",
            )
            mlflow_client.transition_model_version_stage(
                name="Random Forest",
                version=1,
                stage="Production",
            )
        else:
            # Significant difference, but the newcomer is not better.
            print(
                f"Random Forest model's accuracy, {accuracy_rf}, is less than or equal to "
                f"the Logistic Regression model's accuracy, {accuracy_lr}")
            print("Keeping logistic regression model in production")
    else:
        # No significant difference between the models at this level.
        print(
            f"p-value {p_value} greater than significance level {significance}"
        )
        print("Keeping logistic regression model in production")