示例#1
0
def _init_server(backend_uri, root_artifact_uri):
    """
    Launch a new REST server using the tracking store specified by backend_uri and root artifact
    directory specified by root_artifact_uri.

    A fresh temporary directory is created under ``root_artifact_uri`` and used as this server's
    artifact root. The server runs in a child process started with the same Python interpreter
    that runs this code. This function blocks in ``_await_server_up_or_die``; if that call raises,
    the child process is terminated before the error propagates, so it does not outlive the caller.

    :param backend_uri: Backend store URI, passed to the server through
                        ``BACKEND_STORE_URI_ENV_VAR``.
    :param root_artifact_uri: Local file URI of the directory under which the server's artifact
                              root is created.
    :returns A tuple (url, process) containing the string URL of the server and a handle to the
             server process (a subprocess.Popen object).
    """
    import sys

    kiwi.set_tracking_uri(None)
    server_port = get_safe_port()
    artifact_root = tempfile.mkdtemp(dir=local_file_uri_to_path(root_artifact_uri))
    env = {
        BACKEND_STORE_URI_ENV_VAR: backend_uri,
        ARTIFACT_ROOT_ENV_VAR: path_to_local_file_uri(artifact_root),
    }
    # NOTE(review): the server is imported from `mlflow.server` while this block otherwise uses
    # `kiwi`; confirm that is the intended package.
    launch_code = 'from mlflow.server import app; app.run("{hostname}", {port})'.format(
        hostname=LOCALHOST, port=server_port)
    # Use the interpreter running the tests instead of whatever "python" is first on PATH, so the
    # server sees the same installed packages / virtualenv.
    cmd = [sys.executable, "-c", launch_code]
    # The child copies os.environ when it is spawned, so patching it only around Popen suffices.
    with mock.patch.dict(os.environ, env):
        process = Popen(cmd)

    try:
        _await_server_up_or_die(server_port)
    except BaseException:
        # Don't leave a server process running if start-up failed or we were interrupted.
        process.terminate()
        raise
    url = "http://{hostname}:{port}".format(hostname=LOCALHOST, port=server_port)
    print("Launching tracking server against backend URI %s. Server URL: %s" %
          (backend_uri, url))
    return url, process
示例#2
0
def test_model_log(h2o_iris_model):
    """
    Log an H2O model with ``kiwi.h2o.log_model``, reload it from its ``runs:/`` URI and check that
    the reloaded model's predictions equal the original model's predictions.

    The round trip is done twice: once without an explicitly started run and once after
    ``kiwi.start_run()``. In both cases ``kiwi.active_run()`` is read after logging to build the
    model URI. Each iteration runs inside ``TempDir(chdr=True, remove_on_exit=True)`` with the
    tracking URI set to ``"test"``. The active run is ended and the previous tracking URI is
    restored afterwards.
    """
    h2o_model = h2o_iris_model.model
    old_uri = kiwi.get_tracking_uri()
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    for should_start_run in [False, True]:
        with TempDir(chdr=True, remove_on_exit=True):
            try:
                artifact_path = "gbm_model"
                kiwi.set_tracking_uri("test")
                if should_start_run:
                    kiwi.start_run()
                kiwi.h2o.log_model(h2o_model=h2o_model, artifact_path=artifact_path)
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=kiwi.active_run().info.run_id,
                    artifact_path=artifact_path)

                # Load model
                h2o_model_loaded = kiwi.h2o.load_model(model_uri=model_uri)
                inference_data = h2o_iris_model.inference_data
                loaded_preds = h2o_model_loaded.predict(inference_data).as_data_frame()
                original_preds = h2o_model.predict(inference_data).as_data_frame()
                # Reduce the element-wise comparison over both axes. The builtin all() on a
                # DataFrame iterates its column labels (always truthy), so it never fails.
                assert (loaded_preds == original_preds).all().all()
            finally:
                kiwi.end_run()
                kiwi.set_tracking_uri(old_uri)
示例#3
0
def test_sparkml_estimator_model_log(tmpdir, spark_model_estimator):
    """
    Log a Spark ML model with ``sparkm.log_model``, reload it from its ``runs:/`` URI and check
    that the reloaded model reproduces the fixture's expected predictions.

    Every combination of the following is exercised:
      * ``should_start_run``: whether a run is started explicitly before logging;
      * ``dfs_tmp_dir``: ``None`` (the module default) or ``<tmpdir>/test``.
    Each combination logs under its own artifact path (``model0``, ``model1``, ...). After each
    combination the run is ended, the previous tracking URI is restored, and the DFS temp
    directory and the file tracking directory are removed.
    """
    import itertools

    old_tracking_uri = kiwi.get_tracking_uri()
    tracking_dir = os.path.abspath(str(tmpdir.join("mlruns")))
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    combinations = itertools.product([False, True], [None, os.path.join(str(tmpdir), "test")])
    for cnt, (should_start_run, dfs_tmp_dir) in enumerate(combinations):
        print("should_start_run =", should_start_run, "dfs_tmp_dir =", dfs_tmp_dir)
        try:
            kiwi.set_tracking_uri("file://%s" % tracking_dir)
            if should_start_run:
                kiwi.start_run()
            artifact_path = "model%d" % cnt
            sparkm.log_model(
                artifact_path=artifact_path,
                spark_model=spark_model_estimator.model,
                dfs_tmpdir=dfs_tmp_dir)
            model_uri = "runs:/{run_id}/{artifact_path}".format(
                run_id=kiwi.active_run().info.run_id,
                artifact_path=artifact_path)

            # test reloaded model
            reloaded_model = sparkm.load_model(model_uri=model_uri, dfs_tmpdir=dfs_tmp_dir)
            preds_df = reloaded_model.transform(spark_model_estimator.spark_df)
            preds = [x.prediction for x in preds_df.select("prediction").collect()]
            assert spark_model_estimator.predictions == preds
        finally:
            kiwi.end_run()
            kiwi.set_tracking_uri(old_tracking_uri)
            # ignore_errors: if logging failed before a directory was created, a cleanup
            # FileNotFoundError must not replace the real failure raised in the try body.
            shutil.rmtree(dfs_tmp_dir or sparkm.DFS_TMP, ignore_errors=True)
            shutil.rmtree(tracking_dir, ignore_errors=True)
示例#4
0
文件: conftest.py 项目: iPieter/kiwi
def tracking_uri_mock(tmpdir, request):
    """
    Fixture generator that points kiwi at a per-test SQLite tracking store under ``tmpdir``.

    Unless the requesting test's keywords include ``notrackingurimock``, it does two things:
      * the tracking URI is set to a SQLite URI for ``<tmpdir>/mlruns``;
      * the same URI is exported as ``MLFLOW_TRACKING_URI``.
    Yields ``tmpdir``. On teardown the kiwi tracking URI is reset to ``None``. If the
    environment variable was overridden, it is restored to its previous value, or removed if it
    was previously unset.
    """
    env_var = "MLFLOW_TRACKING_URI"
    mock_tracking_uri = 'notrackingurimock' not in request.keywords
    previous_env_uri = os.environ.get(env_var)
    try:
        if mock_tracking_uri:
            tracking_uri = path_to_local_sqlite_uri(
                os.path.join(tmpdir.strpath, 'mlruns'))
            kiwi.set_tracking_uri(tracking_uri)
            os.environ[env_var] = tracking_uri
        yield tmpdir
    finally:
        kiwi.set_tracking_uri(None)
        if mock_tracking_uri:
            if previous_env_uri is None:
                # pop() instead of del: setup may have failed before the variable was set, and
                # a KeyError here would hide the original error.
                os.environ.pop(env_var, None)
            else:
                os.environ[env_var] = previous_env_uri
示例#5
0
def test_docker_project_tracking_uri_propagation(ProfileConfigProvider, tmpdir,
                                                 tracking_uri,
                                                 expected_command_segment,
                                                 docker_example_base_image):  # pylint: disable=unused-argument
    """
    Run the docker example project with ``tracking_uri`` configured as the tracking URI.

    The store lookup ``mlflow.tracking._tracking_service.utils._get_store`` is patched to return
    a ``FileStore`` rooted at ``<tmpdir>/mlruns``. When ``tracking_uri`` is ``None``, that same
    local directory is used as the tracking URI. ``ProfileConfigProvider`` (presumably injected by
    a ``mock.patch`` decorator outside this view) is made to return a fixed ``DatabricksConfig``.
    The previously configured tracking URI is restored afterwards.

    NOTE(review): ``expected_command_segment`` is accepted but never asserted on here; confirm
    whether a check on the generated docker command was intended.
    """
    mlruns_dir = os.path.join(tmpdir.strpath, "mlruns")
    effective_uri = mlruns_dir if tracking_uri is None else tracking_uri

    provider = mock.MagicMock()
    provider.get_config.return_value = DatabricksConfig(
        "host", "user", "pass", None, insecure=True)
    ProfileConfigProvider.return_value = provider

    previous_uri = kiwi.get_tracking_uri()
    try:
        kiwi.set_tracking_uri(effective_uri)
        store_target = "mlflow.tracking._tracking_service.utils._get_store"
        with mock.patch(store_target) as get_store:
            get_store.return_value = file_store.FileStore(mlruns_dir)
            kiwi.projects.run(
                TEST_DOCKER_PROJECT_DIR,
                experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID)
    finally:
        kiwi.set_tracking_uri(previous_uri)
示例#6
0
def test_get_tracking_uri_for_run():
    """
    Check how ``databricks._get_tracking_uri_for_run`` maps configured tracking URIs:
      * a plain HTTP URI is passed through unchanged;
      * a ``databricks://<profile>`` URI collapses to ``"databricks"``.
    Then, with the explicit tracking URI cleared, check that the tracking-URI environment
    variable is what ``get_tracking_uri`` resolves to.
    """
    cases = (
        ("http://some-uri", "http://some-uri"),
        ("databricks://profile", "databricks"),
    )
    for configured_uri, expected_uri in cases:
        kiwi.set_tracking_uri(configured_uri)
        assert databricks._get_tracking_uri_for_run() == expected_uri

    kiwi.set_tracking_uri(None)
    env_override = {kiwi.tracking._TRACKING_URI_ENV_VAR: "http://some-uri"}
    with mock.patch.dict(os.environ, env_override):
        resolved_uri = kiwi.tracking._tracking_service.utils.get_tracking_uri()
        assert resolved_uri == "http://some-uri"
示例#7
0
def http_tracking_uri_mock():
    """
    Generator (presumably registered as a pytest fixture outside this view) that sets the tracking
    URI to a dummy HTTP endpoint while the consumer runs. Afterwards it clears the tracking URI
    back to ``None``.
    """
    fake_server_uri = "http://some-cool-uri"
    kiwi.set_tracking_uri(fake_server_uri)
    yield
    kiwi.set_tracking_uri(None)