Example No. 1
def _before_run_validations(tracking_uri, cluster_spec):
    """Validations to perform before running a project on Databricks."""
    _check_databricks_auth_available()
    if cluster_spec is None:
        raise ExecutionException("Cluster spec must be provided when launching MLflow project runs "
                                 "on Databricks.")
    if tracking.is_local_uri(tracking_uri):
        raise ExecutionException(
            "When running on Databricks, the MLflow tracking URI must be set to a remote URI "
            "accessible to both the current client and code running on Databricks. Got local "
            "tracking URI %s." % tracking_uri)
Example No. 2
def _validate_parameters(self, user_parameters):
    """Raise an ExecutionException if any parameter without a default value is missing
    from user_parameters."""
    from mlflow.projects import ExecutionException
    missing_params = []
    for name in self.parameters:
        if name not in user_parameters and self.parameters[name].default is None:
            missing_params.append(name)
    if len(missing_params) == 1:
        raise ExecutionException(
            "No value given for missing parameter: '%s'" % missing_params[0])
    elif len(missing_params) > 1:
        raise ExecutionException(
            "No value given for missing parameters: %s" %
            ", ".join(["'%s'" % name for name in missing_params]))
Example No. 3
def run_databricks(uri, entry_point, version, parameters, experiment_id,
                   cluster_spec, git_username, git_password):
    """
    Runs the project at the specified URI on Databricks, returning a `SubmittedRun` that can be
    used to query the run's status or wait for the resulting Databricks Job run to terminate.
    """
    _check_databricks_auth_available()
    if cluster_spec is None:
        raise ExecutionException(
            "Cluster spec must be provided when launching MLflow project runs "
            "on Databricks.")

    # Fetch the project into work_dir & validate parameters
    work_dir = _fetch_project(uri=uri,
                              use_temp_cwd=True,
                              version=version,
                              git_username=git_username,
                              git_password=git_password)
    project = _load_project(work_dir)
    project.get_entry_point(entry_point)._validate_parameters(parameters)
    # Upload the project to DBFS, get the URI of the project
    dbfs_project_uri = _upload_project_to_dbfs(work_dir, experiment_id)

    # Create run object with remote tracking server. Get the git commit from the working directory,
    # etc.
    tracking_uri = tracking.get_tracking_uri()
    remote_run = _create_databricks_run(
        tracking_uri=tracking_uri,
        experiment_id=experiment_id,
        source_name=_expand_uri(uri),
        source_version=tracking._get_git_commit(work_dir),
        entry_point_name=entry_point)
    # Set up environment variables for remote execution
    env_vars = {}
    if experiment_id is not None:
        eprint("=== Using experiment ID %s ===" % experiment_id)
        env_vars[tracking._EXPERIMENT_ID_ENV_VAR] = experiment_id
    if remote_run is not None:
        env_vars[tracking._TRACKING_URI_ENV_VAR] = tracking.get_tracking_uri()
        env_vars[tracking._RUN_ID_ENV_VAR] = remote_run.run_info.run_uuid
    eprint("=== Running entry point %s of project %s on Databricks. ===" %
           (entry_point, uri))
    # Launch run on Databricks
    with open(cluster_spec, 'r') as handle:
        try:
            cluster_spec = json.load(handle)
        except ValueError:
            eprint(
                "Error when attempting to load and parse JSON cluster spec from file "
                "%s. " % cluster_spec)
            raise
    fuse_dst_dir = os.path.join(
        "/dbfs/",
        _parse_dbfs_uri_path(dbfs_project_uri).lstrip("/"))
    final_run_id = remote_run.run_info.run_uuid if remote_run else None
    command = _get_databricks_run_cmd(fuse_dst_dir, final_run_id, entry_point,
                                      parameters)
    db_run_id = _run_shell_command_job(uri, command, env_vars, cluster_spec)
    return DatabricksSubmittedRun(db_run_id, final_run_id)
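
A hedged invocation sketch; every argument value below is a placeholder rather than something taken from the source.

# Hypothetical call: run the "main" entry point of a Git-hosted project on a
# new Databricks cluster described by a local JSON spec file.
submitted_run = run_databricks(
    uri="https://github.com/mlflow/mlflow-example",
    entry_point="main",
    version=None,
    parameters={"alpha": "0.5"},
    experiment_id=0,
    cluster_spec="cluster_spec.json",
    git_username=None,
    git_password=None)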
Example No. 4
def _check_databricks_auth_available():
    try:
        process.exec_cmd(["databricks", "--version"])
    except process.ShellCommandException:
        raise ExecutionException(
            "Could not find Databricks CLI on PATH. Please install and configure the Databricks "
            "CLI as described in https://github.com/databricks/databricks-cli")
    # Verify that we can get Databricks auth
    rest_utils.get_databricks_hostname_and_auth()
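
A small hypothetical sketch of reporting the failure this helper detects.

# Hypothetical: surface a readable message when the Databricks CLI is missing
# or not configured for the current environment.
try:
    _check_databricks_auth_available()
except ExecutionException as e:
    print("Databricks prerequisites not met: %s" % e)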
Example No. 5
def _compute_path_value(self, user_param_value, storage_dir):
    """Resolve a path parameter: return the absolute path of an existing local file or
    directory, or download a URI into storage_dir and return the local copy's path."""
    from mlflow.projects import ExecutionException
    if not data.is_uri(user_param_value):
        if not os.path.exists(user_param_value):
            raise ExecutionException("Got value %s for parameter %s, but no such file or "
                                     "directory was found." % (user_param_value, self.name))
        return os.path.abspath(user_param_value)
    basename = os.path.basename(user_param_value)
    dest_path = os.path.join(storage_dir, basename)
    if dest_path != user_param_value:
        data.download_uri(uri=user_param_value, output_path=dest_path)
    return os.path.abspath(dest_path)
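
A hedged sketch of the two branches above; param stands in for a path-typed parameter object and both paths are placeholders.

# Hypothetical: a local file is returned as an absolute path, while a URI
# would first be downloaded into storage_dir.
local_path = param._compute_path_value("data/train.csv", storage_dir="/tmp/mlflow")
remote_path = param._compute_path_value("s3://my-bucket/train.csv", storage_dir="/tmp/mlflow")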
Example No. 6
def get_entry_point(self, entry_point):
    """Return the entry point with the given name, or synthesize one that runs a .py or .sh
    script in the project; raise an ExecutionException otherwise."""
    from mlflow.projects import ExecutionException
    if entry_point in self.entry_points:
        return self.entry_points[entry_point]
    _, file_extension = os.path.splitext(entry_point)
    ext_to_cmd = {".py": "python", ".sh": os.environ.get("SHELL", "bash")}
    if file_extension in ext_to_cmd:
        command = "%s %s" % (ext_to_cmd[file_extension], shlex_quote(entry_point))
        return EntryPoint(name=entry_point, parameters={}, command=command)
    raise ExecutionException("Could not find {0} among entry points {1} or interpret {0} as a "
                             "runnable script. Supported script file extensions: "
                             "{2}".format(entry_point, list(self.entry_points.keys()),
                                          list(ext_to_cmd.keys())))
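
A hedged sketch of the two lookup paths; project and the entry-point names are illustrative.

# Hypothetical: "main" is assumed to be declared in the MLproject file, while
# "train.py" is not and falls back to the file-extension rule above.
declared = project.get_entry_point("main")      # returns the declared EntryPoint
inferred = project.get_entry_point("train.py")  # synthesized as "python train.py"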
Example No. 7
def _dbfs_path_exists(dbfs_uri):
    """
    Returns True if the passed-in path exists in DBFS for the workspace corresponding to the
    default Databricks CLI profile.
    """
    dbfs_path = _parse_dbfs_uri_path(dbfs_uri)
    json_response_obj = rest_utils.databricks_api_request(
        endpoint="dbfs/get-status", method="GET", json={"path": dbfs_path})
    # If request fails with a RESOURCE_DOES_NOT_EXIST error, the file does not exist on DBFS
    error_code_field = "error_code"
    if error_code_field in json_response_obj:
        if json_response_obj[error_code_field] == "RESOURCE_DOES_NOT_EXIST":
            return False
        raise ExecutionException("Got unexpected error response when checking whether file %s "
                                 "exists in DBFS: %s" % json_response_obj)
    return True
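
A hedged usage sketch; the DBFS URI below is a placeholder.

# Hypothetical: skip re-uploading a project archive that already exists on DBFS.
if not _dbfs_path_exists("dbfs:/mlflow-projects/example-project.tar.gz"):
    print("Archive not found on DBFS; it would need to be uploaded.")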
Example No. 8
def _compute_uri_value(self, user_param_value):
    """Validate that the value provided for a URI parameter is a URI and return it unchanged."""
    from mlflow.projects import ExecutionException
    if not data.is_uri(user_param_value):
        raise ExecutionException("Expected URI for parameter %s but got "
                                 "%s" % (self.name, user_param_value))
    return user_param_value
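
A hedged sketch; param stands in for a URI-typed parameter object.

# Hypothetical: a URI is returned unchanged, while a bare local path raises.
param._compute_uri_value("s3://my-bucket/data")  # returned as-is
param._compute_uri_value("./local/data")         # raises ExecutionException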