Example #1

# Standard-library imports used by the snippet; `tracking`, `eprint`, and the
# underscore-prefixed helpers are internals of the surrounding MLflow
# projects module.
import json
import os

def run_databricks(uri, entry_point, version, parameters, experiment_id,
                   cluster_spec, git_username, git_password):
    """
    Runs the project at the specified URI on Databricks, returning a `SubmittedRun` that can be
    used to query the run's status or wait for the resulting Databricks Job run to terminate.
    """
    _check_databricks_auth_available()
    if cluster_spec is None:
        raise ExecutionException(
            "Cluster spec must be provided when launching MLflow project runs "
            "on Databricks.")

    # Fetch the project into work_dir & validate parameters
    work_dir = _get_work_dir(uri, use_temp_cwd=True)
    _fetch_project(uri, version, work_dir, git_username, git_password)
    project = _load_project(work_dir, uri)
    project.get_entry_point(entry_point)._validate_parameters(parameters)
    # Upload the project to DBFS, get the URI of the project
    dbfs_project_uri = _upload_project_to_dbfs(work_dir, experiment_id)

    # Create a run against the remote tracking server, recording the source
    # URI, the git commit of the working directory, and the entry point name
    tracking_uri = tracking.get_tracking_uri()
    remote_run = _create_databricks_run(
        tracking_uri=tracking_uri,
        experiment_id=experiment_id,
        source_name=_expand_uri(uri),
        source_version=tracking._get_git_commit(work_dir),
        entry_point_name=entry_point)
    # Set up environment variables for remote execution
    env_vars = {}
    if experiment_id is not None:
        eprint("=== Using experiment ID %s ===" % experiment_id)
        env_vars[tracking._EXPERIMENT_ID_ENV_VAR] = experiment_id
    if remote_run is not None:
        env_vars[tracking._TRACKING_URI_ENV_VAR] = tracking_uri
        env_vars[tracking._RUN_ID_ENV_VAR] = remote_run.run_info.run_uuid
    eprint("=== Running entry point %s of project %s on Databricks. ===" %
           (entry_point, uri))
    # Parse the JSON cluster spec file into a dict. This rebinds `cluster_spec`
    # from the file path to its parsed contents; if parsing fails, the variable
    # still holds the path, so the error message below names the offending file.
    with open(cluster_spec, 'r') as handle:
        try:
            cluster_spec = json.load(handle)
        except ValueError:
            eprint(
                "Error when attempting to load and parse JSON cluster spec from file "
                "%s" % cluster_spec)
            raise
    # Compute the FUSE-mounted path (/dbfs/...) at which the uploaded project
    # will be visible on the cluster, then launch the entry point as a
    # Databricks shell command job
    fuse_dst_dir = os.path.join(
        "/dbfs/",
        _parse_dbfs_uri_path(dbfs_project_uri).lstrip("/"))
    command = _get_databricks_run_cmd(fuse_dst_dir, entry_point, parameters)
    db_run_id = _run_shell_command_job(uri, command, env_vars, cluster_spec)
    return SubmittedRun(remote_run, DatabricksPollableRun(db_run_id))
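
Below is a minimal usage sketch, not part of the original module: the project URI, experiment ID, and cluster-spec contents are illustrative, and `run_databricks` is an internal helper (public callers normally go through MLflow's project-run entry points). The cluster-spec fields follow the Databricks Jobs API's new_cluster schema.

import json
import tempfile

# Hypothetical cluster spec, written to a temporary JSON file because
# run_databricks expects a file path rather than a dict
example_spec = {
    "spark_version": "4.0.x-scala2.11",
    "node_type_id": "r3.xlarge",
    "num_workers": 2,
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
    json.dump(example_spec, f)
    spec_path = f.name

# Hypothetical invocation; `parameters` must match the entry point's declared
# parameters in the project's MLproject file
submitted = run_databricks(
    uri="https://github.com/mlflow/mlflow-example",
    entry_point="main",
    version=None,
    parameters={"alpha": "0.5"},
    experiment_id="0",
    cluster_spec=spec_path,
    git_username=None,
    git_password=None,
)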
Example #2

# `os` and `shutil` are used by the snippet; `_fetch_project` is an internal
# helper of the surrounding MLflow projects module.
import os
import shutil

def _fetch_and_clean_project(uri, version=None, git_username=None, git_password=None):
    """
    Fetches the project at the passed-in URI & prepares it for upload to DBFS. Returns the path of
    the temporary directory into which the project was fetched.
    """
    work_dir = _fetch_project(
        uri=uri, force_tempdir=True, version=version, git_username=git_username,
        git_password=git_password)
    # Remove the mlruns directory from the fetched project so that local run
    # data doesn't alter the project's content hash and bust the upload cache
    mlruns_dir = os.path.join(work_dir, "mlruns")
    if os.path.exists(mlruns_dir):
        shutil.rmtree(mlruns_dir)
    return work_dir
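
And a short sketch of exercising this helper on its own; the repository URI is illustrative, and the function is assumed importable from the same internal MLflow module as above:

import os

# Fetch the project into a fresh temporary directory and verify that any
# local mlruns/ directory was stripped before upload
work_dir = _fetch_and_clean_project(uri="https://github.com/mlflow/mlflow-example")
assert not os.path.exists(os.path.join(work_dir, "mlruns"))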