Ejemplo n.º 1
0
def get_databricks_host_creds(server_uri=None):
    """
    Build the host credentials used to make HTTP requests to a Databricks server.

    The configuration is first looked up through the Databricks CLI ``provider`` module, using
    the profile that ``get_db_info_from_uri`` extracts from ``server_uri``. When that lookup
    yields no config (or a config without a host) and the URI also carried a non-empty key
    prefix, the credentials are read through ``dbutils`` secrets instead (only if
    ``_get_dbutils()`` returns something usable): the secrets ``<prefix>-host`` and
    ``<prefix>-token`` are fetched from the scope named by the first URI component, and they
    replace the config only when both values are non-empty.

    If no host can be determined, or the resulting config carries neither a
    username/password pair nor a token, handling is delegated to
    ``_fail_malformed_databricks_auth`` (expected to raise -- confirm against that helper).

    NOTE(review): the exact URI syntax is defined by ``get_db_info_from_uri``. Earlier
    documentation of this function mentioned both ``databricks://scope:prefix`` and
    ``databricks://profile-name/path-specifier``; confirm which one the helper implements.

    :param server_uri: A URI that specifies the Databricks profile you want to use for making
        requests.
    :return: ``MlflowHostCreds`` holding the host, either username/password or token
        authentication, and the TLS-verification flag read from the config.
    """
    profile, key_prefix = get_db_info_from_uri(server_uri)

    if hasattr(provider, "get_config"):
        if profile:
            config = provider.ProfileConfigProvider(profile).get_config()
        else:
            config = provider.get_config()
    else:
        # Older provider modules only expose the per-profile lookup.
        _logger.warning(
            "Support for databricks-cli<0.8.0 is deprecated and will be removed"
            " in a future version.")
        config = provider.get_config_for_profile(profile)

    # Fall back to secrets only when the CLI lookup produced no usable host AND the URI
    # supplied a key prefix to build the secret names from.
    if not (config and config.host) and key_prefix:
        dbutils = _get_dbutils()
        if dbutils:
            host = dbutils.secrets.get(scope=profile, key=key_prefix + "-host")
            token = dbutils.secrets.get(scope=profile, key=key_prefix + "-token")
            if host and token:
                config = provider.DatabricksConfig.from_token(
                    host=host, token=token, insecure=False)

    if not (config and config.host):
        _fail_malformed_databricks_auth(profile)

    # Configs without an ``insecure`` attribute are treated as verifying TLS.
    skip_tls_verification = getattr(config, "insecure", False)

    has_basic_auth = config.username is not None and config.password is not None
    if has_basic_auth:
        return MlflowHostCreds(
            config.host,
            username=config.username,
            password=config.password,
            ignore_tls_verification=skip_tls_verification,
        )
    if config.token:
        return MlflowHostCreds(
            config.host,
            token=config.token,
            ignore_tls_verification=skip_tls_verification,
        )
    # Host is known but no usable authentication was found.
    _fail_malformed_databricks_auth(profile)
Ejemplo n.º 2
0
def get_workspace_info_from_databricks_secrets(tracking_uri):
    """
    Look up a workspace's host and id through ``dbutils`` secrets.

    The tracking URI is split by ``get_db_info_from_uri`` into a secret scope and a key
    prefix. The secrets ``<prefix>-workspace-id`` and ``<prefix>-host`` are then read from
    that scope.

    :param tracking_uri: URI handed to ``get_db_info_from_uri`` to obtain scope and prefix.
    :return: A ``(workspace_host, workspace_id)`` tuple, or ``(None, None)`` when the URI
        carries no key prefix or ``_get_dbutils()`` returns nothing usable.
    """
    scope, prefix = get_db_info_from_uri(tracking_uri)
    if not prefix:
        return None, None

    dbutils = _get_dbutils()
    if not dbutils:
        return None, None

    workspace_id = dbutils.secrets.get(scope=scope, key=prefix + "-workspace-id")
    workspace_host = dbutils.secrets.get(scope=scope, key=prefix + "-host")
    return workspace_host, workspace_id
Ejemplo n.º 3
0
    def get_db_store(self):
        """
        Create a ``RestStore`` whose credentials come from ``get_databricks_host_creds``.

        The method adapts to the installed mlflow layout by probing for names:

        * the tracking URI is read from ``mlflow.get_tracking_uri`` and, if that is not
          available, from ``mlflow.tracking.get_tracking_uri``;
        * if ``mlflow.utils.uri.get_db_info_from_uri`` can be imported, the credentials are
          resolved from the tracking URI;
        * otherwise the profile is resolved with ``get_db_profile_from_uri`` (imported from
          ``mlflow.utils.uri`` or, failing that, ``mlflow.tracking.utils``) and the
          credentials are resolved from that profile.

        :return: A ``RestStore`` built around a zero-argument callable that produces the
            host credentials when invoked.
        """
        try:
            tracking_uri = mlflow.get_tracking_uri()
        except (ImportError, AttributeError):
            # A missing module attribute surfaces as AttributeError, not ImportError, so
            # both must be caught for this fallback to be reachable.
            logger.warning(VERSION_WARNING.format("mlflow.get_tracking_uri"))
            tracking_uri = mlflow.tracking.get_tracking_uri()

        from mlflow.utils.databricks_utils import get_databricks_host_creds

        try:
            # Feature probe only: the import succeeding is what selects the URI-based
            # path (the original comment ties this to mlflow 1.10 or above).
            from mlflow.utils.uri import get_db_info_from_uri  # noqa: F401
        except ImportError:
            # Not an error: fall through to the profile-based path below.
            pass
        else:
            return RestStore(lambda: get_databricks_host_creds(tracking_uri))

        try:
            from mlflow.utils.uri import get_db_profile_from_uri
        except ImportError:
            logger.warning(VERSION_WARNING.format("from mlflow"))
            from mlflow.tracking.utils import get_db_profile_from_uri

        profile = get_db_profile_from_uri("databricks")
        logger.info("tracking uri: {} and profile: {}".format(tracking_uri, profile))
        return RestStore(lambda: get_databricks_host_creds(profile))
Ejemplo n.º 4
0
def test_get_db_info_from_uri_errors_invalid_profile(server_uri):
    """
    Check that ``get_db_info_from_uri`` rejects ``server_uri`` with an ``MlflowException``
    whose message matches "Unsupported Databricks profile".

    ``server_uri`` is supplied from outside this function (presumably a pytest
    parametrization that is not visible here).
    """
    expected_message = "Unsupported Databricks profile"
    with pytest.raises(MlflowException, match=expected_message):
        get_db_info_from_uri(server_uri)
Ejemplo n.º 5
0
def test_get_db_info_from_uri_errors_no_netloc(server_uri):
    """
    Check that ``get_db_info_from_uri`` rejects ``server_uri`` with an ``MlflowException``
    whose message matches "URI is formatted incorrectly".

    ``server_uri`` is supplied from outside this function (presumably a pytest
    parametrization that is not visible here).
    """
    expected_message = "URI is formatted incorrectly"
    with pytest.raises(MlflowException, match=expected_message):
        get_db_info_from_uri(server_uri)
Ejemplo n.º 6
0
def test_get_db_info_from_uri(server_uri, result):
    """
    Check that ``get_db_info_from_uri(server_uri)`` returns a value equal to ``result``.

    Both arguments are supplied from outside this function (presumably a pytest
    parametrization that is not visible here).
    """
    parsed = get_db_info_from_uri(server_uri)
    assert parsed == result
Ejemplo n.º 7
0
def test_get_db_info_from_uri_errors(server_uri):
    """
    Check that ``get_db_info_from_uri`` raises ``MlflowException`` for ``server_uri``,
    without constraining the error message.

    ``server_uri`` is supplied from outside this function (presumably a pytest
    parametrization that is not visible here).
    """
    expected_error = MlflowException
    with pytest.raises(expected_error):
        get_db_info_from_uri(server_uri)