Пример #1
0
def _get_docker_tracking_cmd_and_envs(tracking_uri):
    """
    Build the extra command-line arguments and environment variables needed to make
    ``tracking_uri`` usable from inside a container.

    :param tracking_uri: Tracking URI to expose to the container.
    :return: A ``(cmds, env_vars)`` tuple. ``cmds`` is a list of command-line arguments
             (a ``-v`` volume mapping when the URI resolves to a local path, otherwise empty)
             and ``env_vars`` is a dict of environment variable names to values.
    """
    cmds = []
    env_vars = {}

    local_path, container_tracking_uri = _get_local_uri_or_none(tracking_uri)
    if local_path is not None:
        # Map the local tracking directory onto the fixed in-container path and point the
        # tracking URI env var at the container-side URI.
        cmds = ["-v", "%s:%s" % (local_path, _MLFLOW_DOCKER_TRACKING_DIR_PATH)]
        env_vars[tracking._TRACKING_URI_ENV_VAR] = container_tracking_uri
    if is_databricks_uri(tracking_uri):
        db_profile = get_db_profile_from_uri(tracking_uri)
        config = databricks_utils.get_databricks_host_creds(db_profile)
        # We set these via environment variables so that only the current profile is exposed, rather
        # than all profiles in ~/.databrickscfg; maybe better would be to mount the necessary
        # part of ~/.databrickscfg into the container
        env_vars[tracking._TRACKING_URI_ENV_VAR] = 'databricks'
        env_vars['DATABRICKS_HOST'] = config.host
        if config.username:
            env_vars['DATABRICKS_USERNAME'] = config.username
        if config.password:
            env_vars['DATABRICKS_PASSWORD'] = config.password
        if config.token:
            env_vars['DATABRICKS_TOKEN'] = config.token
        if config.ignore_tls_verification:
            # Environment variable values must be strings; store the flag's string form
            # instead of the raw (non-string) value.
            env_vars['DATABRICKS_INSECURE'] = str(config.ignore_tls_verification)
    return cmds, env_vars
Пример #2
0
def run_databricks(remote_run, uri, entry_point, work_dir, parameters, experiment_id, cluster_spec):
    """
    Submit the project located at ``uri`` for execution on Databricks.

    The Databricks profile is derived from the current tracking URI, and the run ID of
    ``remote_run`` is forwarded to the job runner. The returned ``DatabricksSubmittedRun``
    can be used to poll the run's status or to block until the Databricks Job run finishes.
    """
    tracking_uri = tracking.get_tracking_uri()
    databricks_profile = get_db_profile_from_uri(tracking_uri)
    mlflow_run_id = remote_run.info.run_id
    runner = DatabricksJobRunner(databricks_profile=databricks_profile)
    databricks_run_id = runner.run_databricks(
        uri,
        entry_point,
        work_dir,
        parameters,
        experiment_id,
        cluster_spec,
        mlflow_run_id,
    )
    handle = DatabricksSubmittedRun(databricks_run_id, mlflow_run_id, runner)
    handle._print_description_and_log_tags()
    return handle
Пример #3
0
def _get_databricks_env_vars(tracking_uri):
    """
    Return the environment variables that expose the Databricks credentials for
    ``tracking_uri``.

    An empty dict is returned when ``tracking_uri`` is not a Databricks URI. Otherwise the
    result always carries the tracking URI variable and ``DATABRICKS_HOST``, plus whichever
    of the username, password, token and insecure-TLS settings are set on the resolved
    credentials.
    """
    if not mlflow.utils.uri.is_databricks_uri(tracking_uri):
        return {}

    creds = databricks_utils.get_databricks_host_creds(get_db_profile_from_uri(tracking_uri))
    # Only the active profile is passed along through the environment, instead of exposing
    # every profile in ~/.databrickscfg; mounting just the relevant part of that file into
    # the container might be a better approach.
    exported = {
        tracking._TRACKING_URI_ENV_VAR: 'databricks',
        'DATABRICKS_HOST': creds.host,
    }
    optional_credentials = (
        ('DATABRICKS_USERNAME', 'username'),
        ('DATABRICKS_PASSWORD', 'password'),
        ('DATABRICKS_TOKEN', 'token'),
    )
    for variable_name, attribute_name in optional_credentials:
        value = getattr(creds, attribute_name)
        if value:
            exported[variable_name] = value
    insecure = creds.ignore_tls_verification
    if insecure:
        exported['DATABRICKS_INSECURE'] = str(insecure)
    return exported
Пример #4
0
    def get_db_store(self):
        """
        Build a ``RestStore`` whose credentials are resolved lazily through
        ``get_databricks_host_creds``.

        Several import locations are tried in turn; the existing comments attribute this
        to differences between mlflow versions. A warning is logged whenever a fallback
        location has to be used.

        :return: A ``RestStore`` wrapping a zero-argument callable that fetches the
                 Databricks host credentials.
        """
        # Resolve the current tracking URI, falling back to the ``mlflow.tracking`` location.
        # NOTE(review): a missing ``mlflow.get_tracking_uri`` attribute would normally raise
        # AttributeError rather than ImportError -- confirm this handler can actually fire.
        try:
            tracking_uri = mlflow.get_tracking_uri()
        except ImportError:
            logger.warning(VERSION_WARNING.format("mlflow.get_tracking_uri"))
            tracking_uri = mlflow.tracking.get_tracking_uri()

        from mlflow.utils.databricks_utils import get_databricks_host_creds
        try:
            # If get_db_info_from_uri exists, it means mlflow 1.10 or above
            from mlflow.utils.uri import get_db_info_from_uri
            # NOTE(review): ``profile`` and ``path`` are not used in this branch; the
            # credentials callable below receives ``tracking_uri`` instead -- confirm intended.
            profile, path = get_db_info_from_uri("databricks")

            return RestStore(lambda: get_databricks_host_creds(tracking_uri))
        except ImportError:
            # get_db_info_from_uri is unavailable: look up get_db_profile_from_uri instead,
            # first under mlflow.utils.uri and then under mlflow.tracking.utils.
            try:
                from mlflow.utils.uri import get_db_profile_from_uri
            except ImportError:
                logger.warning(VERSION_WARNING.format("from mlflow"))
                from mlflow.tracking.utils import get_db_profile_from_uri

            # The profile is derived from the literal "databricks" URI, not from tracking_uri.
            profile = get_db_profile_from_uri("databricks")
            logger.info("tracking uri: {} and profile: {}".format(tracking_uri, profile))
            return RestStore(lambda: get_databricks_host_creds(profile))
Пример #5
0
def test_get_db_profile_from_uri_casing():
    """The profile extracted from a Databricks URI keeps its original letter casing."""
    mixed_case_uri = 'databricks://aAbB'
    expected_profile = 'aAbB'
    assert get_db_profile_from_uri(mixed_case_uri) == expected_profile
Пример #6
0
 def _call_endpoint(self, service, api, json_body):
     """
     Send ``json_body`` to the endpoint registered for ``api`` on ``service``.

     Host credentials are resolved from the profile of the current tracking URI, and the
     endpoint path and HTTP method are looked up in ``_SERVICE_AND_METHOD_TO_INFO``. The
     result of ``call_endpoint`` is returned unchanged.
     """
     tracking_uri = mlflow.tracking.get_tracking_uri()
     host_creds = get_databricks_host_creds(get_db_profile_from_uri(tracking_uri))
     route = _SERVICE_AND_METHOD_TO_INFO[service][api]
     endpoint, method = route
     # A fresh response message is handed to call_endpoint as the final argument.
     return call_endpoint(host_creds, endpoint, method, json_body, api.Response())
Пример #7
0
def _get_databricks_rest_store(store_uri, **_):
    """
    Create a ``DatabricksRestStore`` for ``store_uri``.

    The profile is extracted from the URI immediately; the host credentials themselves are
    only fetched when the store invokes the callable it is given. Extra keyword arguments
    are accepted and ignored.
    """
    databricks_profile = get_db_profile_from_uri(store_uri)

    def _resolve_host_creds():
        return get_databricks_host_creds(databricks_profile)

    return DatabricksRestStore(_resolve_host_creds)