def _get_docker_tracking_cmd_and_envs(tracking_uri):
    """
    Compute the extra command arguments and environment variables needed for a
    container to log to ``tracking_uri``.

    :param tracking_uri: Tracking URI that the containerized run should use.
    :return: A ``(cmds, env_vars)`` tuple. ``cmds`` is a list of command-line arguments
             (a ``-v`` volume mount when ``_get_local_uri_or_none`` reports a local path,
             otherwise empty). ``env_vars`` is a dict mapping environment variable names
             to values.
    """
    cmds = []
    env_vars = {}

    local_path, container_tracking_uri = _get_local_uri_or_none(tracking_uri)
    if local_path is not None:
        # Mount the local tracking directory into the container and point the
        # tracking URI env var at the container-side URI.
        cmds = ["-v", "%s:%s" % (local_path, _MLFLOW_DOCKER_TRACKING_DIR_PATH)]
        env_vars[tracking._TRACKING_URI_ENV_VAR] = container_tracking_uri

    if is_databricks_uri(tracking_uri):
        db_profile = get_db_profile_from_uri(tracking_uri)
        config = databricks_utils.get_databricks_host_creds(db_profile)
        # We set these via environment variables so that only the current profile is exposed,
        # rather than all profiles in ~/.databrickscfg; maybe better would be to mount the
        # necessary part of ~/.databrickscfg into the container
        env_vars[tracking._TRACKING_URI_ENV_VAR] = 'databricks'
        env_vars['DATABRICKS_HOST'] = config.host
        if config.username:
            env_vars['DATABRICKS_USERNAME'] = config.username
        if config.password:
            env_vars['DATABRICKS_PASSWORD'] = config.password
        if config.token:
            env_vars['DATABRICKS_TOKEN'] = config.token
        if config.ignore_tls_verification:
            # Environment variable values must be strings; stringify the flag so the
            # mapping is safe to hand to anything that expects str -> str.
            env_vars['DATABRICKS_INSECURE'] = str(config.ignore_tls_verification)
    return cmds, env_vars
def run_databricks(remote_run, uri, entry_point, work_dir, parameters, experiment_id,
                   cluster_spec):
    """
    Launch the project located at ``uri`` on Databricks.

    The Databricks profile is derived from the current tracking URI, and the run ID of
    ``remote_run`` is forwarded to the job runner. Returns a ``DatabricksSubmittedRun``
    wrapping the resulting Databricks run; its description is printed and its tags are
    logged before it is returned.
    """
    databricks_profile = get_db_profile_from_uri(tracking.get_tracking_uri())
    mlflow_run_id = remote_run.info.run_id
    job_runner = DatabricksJobRunner(databricks_profile=databricks_profile)
    databricks_run_id = job_runner.run_databricks(
        uri,
        entry_point,
        work_dir,
        parameters,
        experiment_id,
        cluster_spec,
        mlflow_run_id,
    )
    run_handle = DatabricksSubmittedRun(databricks_run_id, mlflow_run_id, job_runner)
    run_handle._print_description_and_log_tags()
    return run_handle
def _get_databricks_env_vars(tracking_uri):
    """
    Return the environment variables that expose the Databricks credentials for
    ``tracking_uri``.

    An empty dict is returned when ``tracking_uri`` is not a Databricks URI. Otherwise
    the tracking URI variable and ``DATABRICKS_HOST`` are always set, and the username,
    password, token and insecure flag are included only when they are truthy.
    """
    if not mlflow.utils.uri.is_databricks_uri(tracking_uri):
        return {}

    creds = databricks_utils.get_databricks_host_creds(get_db_profile_from_uri(tracking_uri))

    # Credentials travel as environment variables so that only the active profile is
    # exposed, not every profile in ~/.databrickscfg. Mounting just the relevant part of
    # ~/.databrickscfg into the container might be a better alternative.
    result = {
        tracking._TRACKING_URI_ENV_VAR: 'databricks',
        'DATABRICKS_HOST': creds.host,
    }
    optional_credentials = (
        ('DATABRICKS_USERNAME', creds.username),
        ('DATABRICKS_PASSWORD', creds.password),
        ('DATABRICKS_TOKEN', creds.token),
    )
    for env_name, env_value in optional_credentials:
        if env_value:
            result[env_name] = env_value

    insecure = creds.ignore_tls_verification
    if insecure:
        result['DATABRICKS_INSECURE'] = str(insecure)
    return result
def get_db_store(self):
    """
    Build a ``RestStore`` whose host credentials are resolved lazily through
    ``get_databricks_host_creds``.

    Import fallbacks are used to cope with different mlflow layouts: when
    ``mlflow.utils.uri.get_db_info_from_uri`` can be imported, the credentials are
    resolved from the tracking URI; otherwise they are resolved from the profile
    returned by ``get_db_profile_from_uri("databricks")``.
    """
    try:
        uri = mlflow.get_tracking_uri()
    except ImportError:
        logger.warning(VERSION_WARNING.format("mlflow.get_tracking_uri"))
        uri = mlflow.tracking.get_tracking_uri()

    from mlflow.utils.databricks_utils import get_databricks_host_creds

    try:
        # get_db_info_from_uri being importable indicates mlflow 1.10 or newer.
        from mlflow.utils.uri import get_db_info_from_uri

        # The parsed values are not used below; the call is kept as in the original flow.
        _profile, _path = get_db_info_from_uri("databricks")

        def _creds_from_tracking_uri():
            return get_databricks_host_creds(uri)

        return RestStore(_creds_from_tracking_uri)
    except ImportError:
        try:
            from mlflow.utils.uri import get_db_profile_from_uri
        except ImportError:
            logger.warning(VERSION_WARNING.format("from mlflow"))
            from mlflow.tracking.utils import get_db_profile_from_uri

        db_profile = get_db_profile_from_uri("databricks")
        logger.info("tracking uri: {} and profile: {}".format(uri, db_profile))

        def _creds_from_profile():
            return get_databricks_host_creds(db_profile)

        return RestStore(_creds_from_profile)
def test_get_db_profile_from_uri_casing():
    """The profile extracted from a Databricks URI keeps its original letter casing."""
    expected_profile = 'aAbB'
    parsed_profile = get_db_profile_from_uri('databricks://aAbB')
    assert parsed_profile == expected_profile
def _call_endpoint(self, service, api, json_body):
    """
    Send ``json_body`` to the endpoint registered for ``api`` under ``service``.

    Host credentials are resolved from the profile encoded in the current mlflow
    tracking URI. The endpoint and HTTP method come from
    ``_SERVICE_AND_METHOD_TO_INFO``, and the reply is parsed into a fresh
    ``api.Response()`` message. Returns whatever ``call_endpoint`` returns.
    """
    current_uri = mlflow.tracking.get_tracking_uri()
    host_creds = get_databricks_host_creds(get_db_profile_from_uri(current_uri))
    endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api]
    return call_endpoint(host_creds, endpoint, method, json_body, api.Response())
def _get_databricks_rest_store(store_uri, **_):
    """
    Create a ``DatabricksRestStore`` for the profile encoded in ``store_uri``.

    The store receives a zero-argument callable, so host credentials are looked up
    each time the callable is invoked rather than once up front. Extra keyword
    arguments are accepted and ignored.
    """
    databricks_profile = get_db_profile_from_uri(store_uri)

    def _resolve_host_creds():
        return get_databricks_host_creds(databricks_profile)

    return DatabricksRestStore(_resolve_host_creds)