예제 #1
0
def run_databricks(uri, entry_point, version, parameters, experiment_id,
                   cluster_spec, git_username, git_password):
    """
    Runs the project at the specified URI on Databricks, returning a `SubmittedRun` that can be
    used to query the run's status or wait for the resulting Databricks Job run to terminate.

    :param uri: URI of the project. It is fetched into a working directory, recorded (after
                expansion) as the run's source name, and passed to the job launcher.
    :param entry_point: Name of the project entry point to run; ``parameters`` are validated
                        against it.
    :param version: Project version, forwarded to the project fetcher.
    :param parameters: Parameters for the entry point, validated before anything is uploaded.
    :param experiment_id: Experiment used for the DBFS upload and run creation. When not None it
                          is also exported to the remote job through an environment variable.
    :param cluster_spec: Path to a file containing the cluster specification as JSON.
    :param git_username: Forwarded to the project fetcher.
    :param git_password: Forwarded to the project fetcher.
    :return: A ``DatabricksSubmittedRun`` built from the ID returned by the job launcher and the
             tracking run's UUID (``None`` if no tracking run was created).
    :raises ExecutionException: If ``cluster_spec`` is None.
    :raises ValueError: Re-raised if the cluster spec file does not contain valid JSON.
    """
    _check_databricks_auth_available()
    if cluster_spec is None:
        raise ExecutionException(
            "Cluster spec must be provided when launching MLflow project runs "
            "on Databricks.")

    # Fetch the project into work_dir & validate parameters
    work_dir = _fetch_project(uri=uri,
                              use_temp_cwd=True,
                              version=version,
                              git_username=git_username,
                              git_password=git_password)
    project = _load_project(work_dir)
    project.get_entry_point(entry_point)._validate_parameters(parameters)
    # Upload the project to DBFS, get the URI of the project
    dbfs_project_uri = _upload_project_to_dbfs(work_dir, experiment_id)

    # Create run object with remote tracking server. Get the git commit from the working directory,
    # etc.
    tracking_uri = tracking.get_tracking_uri()
    remote_run = _create_databricks_run(
        tracking_uri=tracking_uri,
        experiment_id=experiment_id,
        source_name=_expand_uri(uri),
        source_version=tracking._get_git_commit(work_dir),
        entry_point_name=entry_point)
    # Resolve the run ID once; it is needed for the environment, the run command and the result.
    run_id = remote_run.run_info.run_uuid if remote_run is not None else None

    # Set up environment variables for remote execution
    env_vars = {}
    if experiment_id is not None:
        eprint("=== Using experiment ID %s ===" % experiment_id)
        env_vars[tracking._EXPERIMENT_ID_ENV_VAR] = experiment_id
    if remote_run is not None:
        # Export the same URI the run was created against rather than querying it a second time.
        env_vars[tracking._TRACKING_URI_ENV_VAR] = tracking_uri
        env_vars[tracking._RUN_ID_ENV_VAR] = run_id
    eprint("=== Running entry point %s of project %s on Databricks. ===" %
           (entry_point, uri))

    # Launch run on Databricks
    # NOTE(review): the cluster spec is only parsed here, after the project upload and run
    # creation above, so an unreadable or malformed spec fails late. Consider validating earlier.
    with open(cluster_spec, 'r') as handle:
        try:
            # Keep the parsed spec in its own name so `cluster_spec` still refers to the file path.
            cluster_spec_json = json.load(handle)
        except ValueError:
            eprint(
                "Error when attempting to load and parse JSON cluster spec from file "
                "%s. " % cluster_spec)
            raise
    # Rewrite the DBFS path under the "/dbfs/" prefix for use in the run command.
    fuse_dst_dir = os.path.join(
        "/dbfs/",
        _parse_dbfs_uri_path(dbfs_project_uri).lstrip("/"))
    command = _get_databricks_run_cmd(fuse_dst_dir, run_id, entry_point,
                                      parameters)
    db_run_id = _run_shell_command_job(uri, command, env_vars, cluster_spec_json)
    return DatabricksSubmittedRun(db_run_id, run_id)
예제 #2
0
def run_databricks(uri, entry_point, version, parameters, experiment_id, cluster_spec,
                   git_username, git_password):
    """
    Launch the project found at ``uri`` on Databricks.

    The project is fetched, its entry point parameters are validated, the project is uploaded to
    DBFS, a tracking run is created, and a shell-command job is submitted using the JSON cluster
    specification read from the file at ``cluster_spec``.

    :return: A ``DatabricksSubmittedRun`` wrapping the ID returned by the job launcher and the
             UUID of the tracking run, usable to query or wait on the Databricks Job run.
    :raises ValueError: Re-raised if the cluster spec file does not contain valid JSON.
    """
    tracking_uri = tracking.get_tracking_uri()
    _before_run_validations(tracking_uri, cluster_spec)

    # Fetch the project locally and check the parameters before uploading anything.
    project_dir = _fetch_and_clean_project(
        uri=uri, version=version, git_username=git_username, git_password=git_password)
    target_entry_point = _load_project(project_dir).get_entry_point(entry_point)
    target_entry_point._validate_parameters(parameters)
    project_dbfs_uri = _upload_project_to_dbfs(project_dir, experiment_id)

    # Register the run with the tracking service.
    tracked_run = tracking._create_run(
        experiment_id=experiment_id,
        source_name=_expand_uri(uri),
        source_version=tracking._get_git_commit(project_dir),
        entry_point_name=entry_point,
        source_type=SourceType.PROJECT)

    # Environment handed to the remote job.
    job_env = {}
    job_env[tracking._TRACKING_URI_ENV_VAR] = tracking_uri
    job_env[tracking._EXPERIMENT_ID_ENV_VAR] = experiment_id
    run_id = tracked_run.run_info.run_uuid

    eprint("=== Running entry point %s of project %s on Databricks. ===" % (entry_point, uri))

    # Read the cluster specification; on a parse failure, report the file path and re-raise.
    with open(cluster_spec, 'r') as spec_file:
        try:
            parsed_cluster_spec = json.load(spec_file)
        except ValueError:
            eprint("Error when attempting to load and parse JSON cluster spec from file "
                   "%s. " % cluster_spec)
            raise

    # Rewrite the DBFS path under the "/dbfs/" prefix for use in the run command.
    relative_project_path = _parse_dbfs_uri_path(project_dbfs_uri).lstrip("/")
    fuse_project_dir = os.path.join("/dbfs/", relative_project_path)
    run_command = _get_databricks_run_cmd(fuse_project_dir, run_id, entry_point, parameters)
    databricks_job_run_id = _run_shell_command_job(
        uri, run_command, job_env, parsed_cluster_spec)
    return DatabricksSubmittedRun(databricks_job_run_id, run_id)