示例#1
0
def test_get_store_file_store(tmpdir):
    """Check that _get_store() returns a FileStore rooted at ./mlruns when called
    without arguments, and rooted at an explicitly passed path otherwise."""
    # NOTE(review): patching with an empty dict leaves any tracking URI that is
    # already set in the outer environment in place.
    with mock.patch.dict(os.environ, {}):
        default_store = _get_store()
        assert isinstance(default_store, FileStore)
        assert default_store.root_directory == os.path.abspath("mlruns")

        # The path argument must be honoured instead of the default location.
        explicit_store = _get_store(tmpdir.strpath)
        assert isinstance(explicit_store, FileStore)
        assert explicit_store.root_directory == tmpdir
示例#2
0
def test_get_store_file_store_from_env(tmp_wkdir, uri):
    """A tracking URI read from the environment resolves to a FileStore whose
    root is ``other/path`` (relative to the current working directory)."""
    patched_env = {_TRACKING_URI_ENV_VAR: uri}
    expected_root = os.path.abspath("other/path")
    with mock.patch.dict(os.environ, patched_env):
        store = _get_store()
        assert isinstance(store, FileStore)
        actual_root = os.path.abspath(store.root_directory)
        assert actual_root == expected_root
示例#3
0
def test_get_store_file_store_from_arg(tmp_wkdir):
    """A relative path passed directly to _get_store() yields a FileStore rooted there."""
    relative_path = "other/path"
    with mock.patch.dict(os.environ, {}):
        store = _get_store(relative_path)
        assert isinstance(store, FileStore)
        actual_root = os.path.abspath(store.root_directory)
        assert actual_root == os.path.abspath(relative_path)
示例#4
0
def _get_host_creds_from_default_store():
    """Return the ``get_host_creds`` method of the default tracking store.

    The bound method itself is returned, not the credentials it produces.
    Presumably callers invoke it later to obtain host credentials; verify
    against call sites.

    :raises MlflowException: if the default store is not a ``RestStore``.
    """
    default_store = utils._get_store()
    if isinstance(default_store, RestStore):
        return default_store.get_host_creds
    raise MlflowException('Failed to get credentials for DBFS; they are read from the '
                          'Databricks CLI credentials or MLFLOW_TRACKING* environment '
                          'variables.')
示例#5
0
def get_artifact_uri(run_id, artifact_path=None):
    """
    Get the absolute URI of the specified artifact in the specified run. If ``artifact_path``
    is not specified, the artifact root URI of the specified run will be returned; calls to
    ``log_artifact`` and ``log_artifacts`` write artifact(s) to subdirectories of the artifact
    root URI.

    :param run_id: The ID of the run for which to obtain an absolute artifact URI.
    :param artifact_path: The run-relative artifact path. For example,
                          ``path/to/artifact``. If unspecified, the artifact root URI for the
                          specified run will be returned.
    :return: An *absolute* URI referring to the specified artifact or the specified run's artifact
             root. For example, if an artifact path is provided and the specified run uses an
             S3-backed store, this may be a uri of the form
             ``s3://<bucket_name>/path/to/artifact/root/path/to/artifact``. If an artifact path
             is not provided and the specified run uses an S3-backed store, this may be a URI of
             the form ``s3://<bucket_name>/path/to/artifact/root``.
    :raises MlflowException: If ``run_id`` is empty, or if the run's own artifact root uses the
                             ``runs`` scheme, which could not be resolved without looping.
    """
    if not run_id:
        raise MlflowException(
            message=
            "A run_id must be specified in order to obtain an artifact uri!",
            error_code=INVALID_PARAMETER_VALUE)

    store = _get_store()
    run = store.get_run(run_id)
    artifact_root = run.info.artifact_uri
    # Maybe move this method to RunsArtifactRepository so the circular dependency is clearer.
    # A ``runs`` artifact root would lead back here and loop forever. This is an explicit check
    # rather than an ``assert`` so that it still runs under ``python -O``.
    if urllib.parse.urlparse(artifact_root).scheme == "runs":
        raise MlflowException(
            message="The artifact root of run '{}' must not use the 'runs' scheme; "
                    "got '{}'.".format(run_id, artifact_root))
    if artifact_path is None:
        return artifact_root
    return posixpath.join(artifact_root, artifact_path)
示例#6
0
def test_get_store_basic_rest_store():
    """An https tracking URI without a token yields a RestStore for that host
    whose credentials carry no token."""
    server_uri = "https://my-tracking-server:5050"
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: server_uri}):
        store = _get_store()
        assert isinstance(store, RestStore)
        assert store.get_host_creds().host == server_uri
        assert store.get_host_creds().token is None
示例#7
0
def _get_model_log_dir(model_name, run_id):
    """Download the artifact ``model_name`` of run ``run_id``.

    The run's artifact root is looked up through the default tracking store.
    The result is whatever the artifact repository's ``download_artifacts``
    returns (presumably a local path; confirm against the repository API).

    :raises Exception: if ``run_id`` is empty or ``None``.
    """
    if run_id:
        artifact_root = _get_store().get_run(run_id).info.artifact_uri
        repository = get_artifact_repository(artifact_root)
        return repository.download_artifacts(model_name)
    raise Exception(
        "Must specify a run_id to get logging directory for a model.")
示例#8
0
def get_service(tracking_uri=None):
    """
    Wrap the store resolved for ``tracking_uri`` in an ``MLflowService``.

    :param tracking_uri: Address of local or remote tracking server. If not provided,
      this will default to the store set by mlflow.tracking.set_tracking_uri. See
      https://mlflow.org/docs/latest/tracking.html#where-runs-get-recorded for more info.
    :return: mlflow.tracking.MLflowService"""
    return MLflowService(_get_store(tracking_uri))
示例#9
0
def test_get_store_rest_store_with_token():
    """A token set in the environment is exposed through the RestStore's host credentials."""
    patched_env = dict([
        (_TRACKING_URI_ENV_VAR, "https://my-tracking-server:5050"),
        (_TRACKING_TOKEN_ENV_VAR, "my-token"),
    ])
    with mock.patch.dict(os.environ, patched_env):
        store = _get_store()
        assert isinstance(store, RestStore)
        assert store.get_host_creds().token == "my-token"
示例#10
0
def test_get_store_rest_store_with_no_insecure():
    """TLS verification stays enabled when the insecure flag is "false" and when
    the flag is not part of the patched environment."""
    server_uri = "https://my-tracking-server:5050"
    scenarios = [
        # Insecure TLS explicitly turned off.
        {_TRACKING_URI_ENV_VAR: server_uri, _TRACKING_INSECURE_TLS_ENV_VAR: "false"},
        # By default, should not ignore verification.
        # NOTE(review): this does not remove an insecure flag already set in the
        # outer environment.
        {_TRACKING_URI_ENV_VAR: server_uri},
    ]
    for patched_env in scenarios:
        with mock.patch.dict(os.environ, patched_env):
            store = _get_store()
            assert isinstance(store, RestStore)
            assert not store.get_host_creds().ignore_tls_verification
示例#11
0
def test_get_store_rest_store_with_insecure(tmpdir):
    """Setting the insecure-TLS flag to "true" disables TLS verification."""
    patched_env = {_TRACKING_URI_ENV_VAR: "https://my-tracking-server:5050"}
    patched_env[_TRACKING_INSECURE_TLS_ENV_VAR] = "true"
    with mock.patch.dict(os.environ, patched_env):
        store = _get_store()
        assert isinstance(store, RestStore)
        assert store.get_host_creds().ignore_tls_verification
示例#12
0
文件: client.py 项目: zahraa1/mlflow
 def __init__(self, tracking_uri=None):
     """
     Create a client bound to a tracking store.

     :param tracking_uri: Address of local or remote tracking server. If not provided, defaults
                          to the service set by ``mlflow.tracking.set_tracking_uri``. See
                          `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
                          for more info.
     """
     if tracking_uri:
         self.tracking_uri = tracking_uri
     else:
         self.tracking_uri = utils.get_tracking_uri()
     self.store = utils._get_store(self.tracking_uri)
示例#13
0
def test_get_store_databricks():
    """The "databricks" tracking URI builds a RestStore from DATABRICKS_HOST/TOKEN."""
    host = "https://my-tracking-server"
    token = "abcdef"
    patched_env = {
        _TRACKING_URI_ENV_VAR: "databricks",
        'DATABRICKS_HOST': host,
        'DATABRICKS_TOKEN': token,
    }
    with mock.patch.dict(os.environ, patched_env):
        store = _get_store()
        assert isinstance(store, RestStore)
        assert store.get_host_creds().host == host
        assert store.get_host_creds().token == token
示例#14
0
 def __init__(self, tracking_uri=None, user=None):
     """
     Create a client bound to a tracking store and a Ranger user.

     :param tracking_uri: Address of local or remote tracking server. If not provided, defaults
                          to the service set by ``mlflow.tracking.set_tracking_uri``. See
                          `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
                          for more info.
     :param user: User name. If falsy, falls back to the ``MLFLOW_RANGER_USER`` environment
                  variable, then to ``'mlflow'``. The resolved value is written back to
                  ``MLFLOW_RANGER_USER`` in ``os.environ`` before the store is created.
     """
     self.tracking_uri = tracking_uri if tracking_uri else utils.get_tracking_uri()
     resolved_user = user if user else os.environ.get('MLFLOW_RANGER_USER', 'mlflow')
     self.user = resolved_user
     # Publish the user through the process environment before building the store.
     os.environ['MLFLOW_RANGER_USER'] = resolved_user
     self.store = utils._get_store(self.tracking_uri)
示例#15
0
文件: service.py 项目: kmader/mlflow
def get_service(tracking_uri=None):
    """
    Get the tracking service backed by the store for ``tracking_uri``.

    :param tracking_uri: Address of local or remote tracking server. If not provided,
      this defaults to the service set by ``mlflow.tracking.set_tracking_uri``. See
      `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_ for more info.
    :return: :py:class:`mlflow.tracking.MLflowService`
    """
    return MLflowService(_get_store(tracking_uri))
示例#16
0
def test_get_store_databricks_profile():
    """A ``databricks://<profile>`` URI yields a RestStore.

    Resolving its credentials for an unconfigured profile must fail with an
    error that names the profile.
    """
    profile = 'mycoolprofile'
    # Setting up a real profile is out of scope here; the test only checks that
    # a relevant exception is raised.
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: "databricks://mycoolprofile"}):
        store = _get_store()
        assert isinstance(store, RestStore)
        with pytest.raises(Exception) as raised:
            store.get_host_creds()
        assert profile in str(raised.value)
示例#17
0
def test_get_store_rest_store_with_password():
    """Username and password from the environment reach the RestStore's host credentials."""
    server_uri = "https://my-tracking-server:5050"
    patched_env = {
        _TRACKING_URI_ENV_VAR: server_uri,
        _TRACKING_USERNAME_ENV_VAR: "Bob",
        _TRACKING_PASSWORD_ENV_VAR: "Ross",
    }
    with mock.patch.dict(os.environ, patched_env):
        store = _get_store()
        assert isinstance(store, RestStore)
        assert store.get_host_creds().host == server_uri
        assert store.get_host_creds().username == "Bob"
        assert store.get_host_creds().password == "Ross"
示例#18
0
def test_get_store_sqlalchemy_store(tmp_wkdir, db_type):
    """A database tracking URI yields a SqlAlchemyStore. Engine creation is
    mocked, and the store must create its engine exactly once with that URI."""
    uri = "{}://hostname/database".format(db_type)
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: uri}):
        with mock.patch("sqlalchemy.create_engine") as mock_create_engine:
            store = _get_store()
            assert isinstance(store, SqlAlchemyStore)
            assert store.db_uri == uri
            assert store.artifact_root_uri == "./mlruns"

    mock_create_engine.assert_called_once_with(uri)
示例#19
0
def test_get_store_sqlalchemy_store(tmp_wkdir, db_type):
    """A database tracking URI yields a SqlAlchemyStore.

    The engine must be created exactly once, with the URI and ``pool_pre_ping=True``.
    """
    uri = "{}://hostname/database".format(db_type)
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: uri}):
        with mock.patch("sqlalchemy.create_engine") as mock_create_engine:
            # create_engine is mocked, so schema verification and table setup are
            # stubbed out as well.
            with mock.patch("mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore._verify_schema"):
                with mock.patch(
                        "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore._initialize_tables"):
                    store = _get_store()
                    assert isinstance(store, SqlAlchemyStore)
                    assert store.db_uri == uri
                    assert store.artifact_root_uri == "./mlruns"

    mock_create_engine.assert_called_once_with(uri, pool_pre_ping=True)