def _init_server(backend_uri, root_artifact_uri):
    """
    Launch a new REST server using the tracking store specified by backend_uri and root artifact
    directory specified by root_artifact_uri.

    A fresh temporary directory is created underneath root_artifact_uri and handed to the server
    as its artifact root. If waiting for the server to come up raises, the child process is
    terminated before the error is re-raised so that no orphaned server is left behind.

    :param backend_uri: Backend store URI, passed to the server through
                        BACKEND_STORE_URI_ENV_VAR.
    :param root_artifact_uri: Local file URI of the directory in which the server's artifact
                              root directory is created.
    :returns A tuple (url, process) containing the string URL of the server and a handle to the
             server process (the object returned by Popen).
    """
    kiwi.set_tracking_uri(None)
    server_port = get_safe_port()
    env = {
        BACKEND_STORE_URI_ENV_VAR: backend_uri,
        ARTIFACT_ROOT_ENV_VAR: path_to_local_file_uri(
            tempfile.mkdtemp(dir=local_file_uri_to_path(root_artifact_uri))),
    }
    cmd = [
        "python",
        "-c",
        'from mlflow.server import app; app.run("{hostname}", {port})'.format(
            hostname=LOCALHOST, port=server_port),
    ]
    # The environment is only patched while the child is being spawned.
    # NOTE(review): assumes Popen is subprocess.Popen, which hands the child the parent's
    # environment at creation time -- confirm against the file's imports.
    with mock.patch.dict(os.environ, env):
        process = Popen(cmd)

    try:
        _await_server_up_or_die(server_port)
    except BaseException:
        # Do not leak a half-started server when start-up fails or is interrupted.
        process.terminate()
        raise

    url = "http://{hostname}:{port}".format(hostname=LOCALHOST, port=server_port)
    print("Launching tracking server against backend URI %s. Server URL: %s" % (backend_uri, url))
    return url, process
def test_model_log(h2o_iris_model):
    """
    Log an H2O model, reload it through its ``runs:/`` URI and check that the reloaded model
    produces the same predictions as the original one.

    The scenario is executed twice: once relying on log_model() to start a run implicitly and
    once with a run started explicitly beforehand.

    :param h2o_iris_model: Fixture object exposing ``model`` and ``inference_data``.
    """
    h2o_model = h2o_iris_model.model
    old_uri = kiwi.get_tracking_uri()
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    for should_start_run in [False, True]:
        with TempDir(chdr=True, remove_on_exit=True):
            try:
                artifact_path = "gbm_model"
                kiwi.set_tracking_uri("test")
                if should_start_run:
                    kiwi.start_run()
                kiwi.h2o.log_model(h2o_model=h2o_model, artifact_path=artifact_path)
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=kiwi.active_run().info.run_id,
                    artifact_path=artifact_path)

                # Load model
                h2o_model_loaded = kiwi.h2o.load_model(model_uri=model_uri)
                loaded_predictions = h2o_model_loaded.predict(
                    h2o_iris_model.inference_data).as_data_frame()
                original_predictions = h2o_model.predict(
                    h2o_iris_model.inference_data).as_data_frame()
                # The builtin all() applied to a frame comparison iterates over the column
                # labels, which are truthy, so it cannot detect differing values. Reduce over
                # the cell values instead (columns first, then the resulting series).
                # NOTE(review): assumes as_data_frame() returns pandas DataFrames -- confirm.
                assert (loaded_predictions == original_predictions).all().all()
            finally:
                # Always close the run and restore the caller's tracking URI.
                kiwi.end_run()
                kiwi.set_tracking_uri(old_uri)
def test_sparkml_estimator_model_log(tmpdir, spark_model_estimator):
    """
    Log a Spark ML estimator model, reload it from its ``runs:/`` URI and compare the reloaded
    model's predictions with the expected ones stored on the fixture.

    Four scenarios are covered: with and without an explicitly started run, each combined with
    the default and a custom DFS temporary directory.

    :param tmpdir: Temporary directory object; hosts the tracking directory and the custom DFS
                   temporary directory.
    :param spark_model_estimator: Fixture object exposing ``model``, ``spark_df`` and
                                  ``predictions``.
    """
    previous_uri = kiwi.get_tracking_uri()
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    scenarios = [
        (should_start_run, dfs_tmp_dir)
        for should_start_run in [False, True]
        for dfs_tmp_dir in [None, os.path.join(str(tmpdir), "test")]
    ]
    # The scenario index doubles as a suffix that keeps artifact paths distinct.
    for index, (should_start_run, dfs_tmp_dir) in enumerate(scenarios):
        print("should_start_run =", should_start_run, "dfs_tmp_dir =", dfs_tmp_dir)
        try:
            tracking_dir = os.path.abspath(str(tmpdir.join("mlruns")))
            kiwi.set_tracking_uri("file://%s" % tracking_dir)
            if should_start_run:
                kiwi.start_run()
            artifact_path = "model%d" % index
            sparkm.log_model(
                artifact_path=artifact_path,
                spark_model=spark_model_estimator.model,
                dfs_tmpdir=dfs_tmp_dir)
            run_id = kiwi.active_run().info.run_id
            model_uri = "runs:/{run_id}/{artifact_path}".format(
                run_id=run_id, artifact_path=artifact_path)

            # Reload the model and score the fixture's data frame with it.
            reloaded = sparkm.load_model(model_uri=model_uri, dfs_tmpdir=dfs_tmp_dir)
            scored = reloaded.transform(spark_model_estimator.spark_df)
            rows = scored.select("prediction").collect()
            preds = [row.prediction for row in rows]
            assert spark_model_estimator.predictions == preds
        finally:
            kiwi.end_run()
            kiwi.set_tracking_uri(previous_uri)
            # Remove whichever DFS temporary directory this scenario used, then the
            # tracking directory.
            shutil.rmtree(dfs_tmp_dir or sparkm.DFS_TMP)
            shutil.rmtree(tracking_dir)
def tracking_uri_mock(tmpdir, request):
    """
    Generator that points tracking at a SQLite store inside tmpdir for the duration of a test.

    Unless the requesting item carries the ``notrackingurimock`` keyword, the tracking URI is
    set both through kiwi.set_tracking_uri() and through the MLFLOW_TRACKING_URI environment
    variable. On exit the tracking URI is always reset to None and the environment variable is
    removed again if this function set it.
    (Presumably registered as a pytest fixture; the decorator is not part of this block.)

    :param tmpdir: Temporary directory object; the store lives in its ``mlruns`` entry.
    :param request: Request object whose ``keywords`` are checked for ``notrackingurimock``.
    :returns Yields tmpdir.
    """
    # Evaluate the opt-out once so that set-up and tear-down always agree with each other.
    use_mock = 'notrackingurimock' not in request.keywords
    try:
        if use_mock:
            tracking_uri = path_to_local_sqlite_uri(
                os.path.join(tmpdir.strpath, 'mlruns'))
            kiwi.set_tracking_uri(tracking_uri)
            os.environ["MLFLOW_TRACKING_URI"] = tracking_uri
        yield tmpdir
    finally:
        kiwi.set_tracking_uri(None)
        if use_mock:
            # pop() with a default instead of del: the variable may be missing when set-up
            # failed early or the test removed it, and a KeyError raised here would mask the
            # original error.
            os.environ.pop("MLFLOW_TRACKING_URI", None)
def test_docker_project_tracking_uri_propagation(ProfileConfigProvider,
                                                 tmpdir,
                                                 tracking_uri,
                                                 expected_command_segment,
                                                 docker_example_base_image):
    """
    Run the Docker test project with a given tracking URI while the tracking store lookup is
    patched to return a FileStore rooted in tmpdir.

    The visible body makes no explicit assertions; the test passes when the project run
    completes without raising.
    (Presumably ProfileConfigProvider is injected by a patch decorator and tracking_uri /
    expected_command_segment come from parametrization; neither is part of this block.)

    :param ProfileConfigProvider: Mock whose return value is replaced by a provider stub.
    :param tmpdir: Temporary directory object hosting the local ``mlruns`` directory.
    :param tracking_uri: Tracking URI to run with; None selects the local tracking directory.
    :param expected_command_segment: Not used in the body.
    :param docker_example_base_image: Not used in the body.
    """
    # pylint: disable=unused-argument
    # Make the Databricks profile lookup return a fixed configuration.
    provider_stub = mock.MagicMock()
    provider_stub.get_config.return_value = DatabricksConfig(
        "host", "user", "pass", None, insecure=True)
    ProfileConfigProvider.return_value = provider_stub

    # Create and mock local tracking directory
    local_tracking_dir = os.path.join(tmpdir.strpath, "mlruns")
    effective_uri = local_tracking_dir if tracking_uri is None else tracking_uri

    previous_uri = kiwi.get_tracking_uri()
    try:
        kiwi.set_tracking_uri(effective_uri)
        store_patch = mock.patch("mlflow.tracking._tracking_service.utils._get_store")
        with store_patch as get_store_stub:
            get_store_stub.return_value = file_store.FileStore(local_tracking_dir)
            kiwi.projects.run(
                TEST_DOCKER_PROJECT_DIR,
                experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID)
    finally:
        # Restore whatever tracking URI was active before the test.
        kiwi.set_tracking_uri(previous_uri)
def test_get_tracking_uri_for_run():
    """
    Check how the tracking URI is resolved:

    * databricks._get_tracking_uri_for_run() returns an http URI unchanged,
    * it returns ``"databricks"`` for a ``databricks://profile`` URI,
    * with no URI set explicitly, get_tracking_uri() returns the value of the tracking URI
      environment variable.
    """
    try:
        kiwi.set_tracking_uri("http://some-uri")
        assert databricks._get_tracking_uri_for_run() == "http://some-uri"
        kiwi.set_tracking_uri("databricks://profile")
        assert databricks._get_tracking_uri_for_run() == "databricks"
    finally:
        # Reset even when an assertion above fails, so that the temporary URI does not leak
        # into the rest of the test session.
        kiwi.set_tracking_uri(None)
    with mock.patch.dict(
            os.environ, {kiwi.tracking._TRACKING_URI_ENV_VAR: "http://some-uri"}):
        assert kiwi.tracking._tracking_service.utils.get_tracking_uri() == "http://some-uri"
def http_tracking_uri_mock():
    """
    Generator that sets the tracking URI to ``http://some-cool-uri`` while the consumer runs
    and resets it to None afterwards.
    (Presumably registered as a pytest fixture; the decorator is not part of this block.)

    :returns Yields nothing (None).
    """
    kiwi.set_tracking_uri("http://some-cool-uri")
    try:
        yield
    finally:
        # Also reset when an exception is thrown into the generator or it is closed early,
        # in line with how tracking_uri_mock cleans up.
        kiwi.set_tracking_uri(None)