Example 1
def run_databricks(uri, entry_point, version, parameters, experiment_id, cluster_spec,
                   git_username, git_password):
    """
    Runs the project at the specified URI on Databricks, returning a ``SubmittedRun`` that can be
    used to query the run's status or wait for the resulting Databricks Job run to terminate.
    """
    tracking_uri = tracking.get_tracking_uri()
    _before_run_validations(tracking_uri, cluster_spec)
    work_dir = _fetch_and_clean_project(
        uri=uri, version=version, git_username=git_username, git_password=git_password)
    project = _load_project(work_dir)
    project.get_entry_point(entry_point)._validate_parameters(parameters)
    dbfs_project_uri = _upload_project_to_dbfs(work_dir, experiment_id)
    remote_run = tracking._create_run(
        experiment_id=experiment_id, source_name=_expand_uri(uri),
        source_version=tracking._get_git_commit(work_dir), entry_point_name=entry_point,
        source_type=SourceType.PROJECT)
    env_vars = {
        tracking._TRACKING_URI_ENV_VAR: tracking_uri,
        tracking._EXPERIMENT_ID_ENV_VAR: experiment_id,
    }
    run_id = remote_run.run_info.run_uuid
    eprint("=== Running entry point %s of project %s on Databricks. ===" % (entry_point, uri))
    # Launch run on Databricks
    with open(cluster_spec, 'r') as handle:
        try:
            cluster_spec = json.load(handle)
        except ValueError:
            eprint("Error when attempting to load and parse JSON cluster spec from file "
                   "%s. " % cluster_spec)
            raise
    fuse_dst_dir = os.path.join("/dbfs/", _parse_dbfs_uri_path(dbfs_project_uri).lstrip("/"))
    command = _get_databricks_run_cmd(fuse_dst_dir, run_id, entry_point, parameters)
    db_run_id = _run_shell_command_job(uri, command, env_vars, cluster_spec)
    return DatabricksSubmittedRun(db_run_id, run_id)
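
Before submission, the cluster spec is parsed from a JSON file. A minimal standalone sketch of that load-and-parse step, reusing only the error-handling pattern above (the helper name and the sample spec keys are illustrative, not part of MLflow's API):

import json
import sys

def load_cluster_spec(path):
    # Parse a Databricks cluster spec from JSON, re-raising on invalid input
    # after printing the offending path, as run_databricks does above.
    with open(path, "r") as handle:
        try:
            return json.load(handle)
        except ValueError:
            print("Error when attempting to load and parse JSON cluster spec "
                  "from file %s." % path, file=sys.stderr)
            raise

# load_cluster_spec("cluster.json") might return a dict such as
# {"spark_version": "2.3.x-scala2.11", "num_workers": 2}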
Example 2
def _create_run(uri, experiment_id, work_dir, entry_point, parameters):
    """
    Create a ``Run`` against the current MLflow tracking server, logging metadata (e.g. the URI,
    entry point, and parameters of the project) about the run. Return an ``ActiveRun`` that can be
    used to report additional data about the run (metrics/params) to the tracking server.
    """
    active_run = tracking._create_run(
        experiment_id=experiment_id, source_name=_expand_uri(uri),
        source_version=tracking._get_git_commit(work_dir), entry_point_name=entry_point,
        source_type=SourceType.PROJECT)
    if parameters is not None:
        for key, value in parameters.items():
            active_run.log_param(Param(key, value))
    return active_run
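
A self-contained sketch of the parameter-logging loop above, using in-memory stand-ins for MLflow's Param entity and the returned run object (assumed here only to mirror the interface used in the example, not the real tracking client):

class Param(object):
    # Stand-in for MLflow's Param entity: a simple key/value pair.
    def __init__(self, key, value):
        self.key = key
        self.value = value

class StubRun(object):
    # Stand-in for ActiveRun that records logged parameters in memory.
    def __init__(self):
        self.params = {}
    def log_param(self, param):
        self.params[param.key] = param.value

active_run = StubRun()
parameters = {"alpha": "0.5", "l1_ratio": "0.1"}
if parameters is not None:
    for key, value in parameters.items():
        active_run.log_param(Param(key, value))
print(active_run.params)  # {'alpha': '0.5', 'l1_ratio': '0.1'}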
Example 3
def _create_databricks_run(tracking_uri, experiment_id, source_name,
                           source_version, entry_point_name):
    """
    Make an API request to the specified tracking server to create a new run with the specified
    attributes. Return an ``ActiveRun`` that can be used to query the tracking server for the run's
    status or log metrics/params for the run.
    """
    if tracking.is_local_uri(tracking_uri):
        eprint(
            "WARNING: MLflow tracking URI is set to a local URI (%s), so results from "
            "Databricks will not be logged permanently." % tracking_uri)
    return tracking._create_run(experiment_id=experiment_id,
                                source_name=source_name,
                                source_version=source_version,
                                entry_point_name=entry_point_name,
                                source_type=SourceType.PROJECT)
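
The warning hinges on tracking.is_local_uri. A plausible simplified version of that check (an assumption about its behavior, not MLflow's actual implementation) treats bare paths and file:// URIs as local:

from urllib.parse import urlparse

def is_local_uri(tracking_uri):
    # Bare paths have no scheme; anything with a real scheme
    # (http, https, databricks, ...) is considered remote.
    scheme = urlparse(tracking_uri).scheme
    return scheme in ("", "file")

print(is_local_uri("/tmp/mlruns"))         # True
print(is_local_uri("file:///tmp/mlruns"))  # True
print(is_local_uri("https://host:5000"))   # False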
Example 4
def _run_project(project, entry_point, work_dir, parameters, use_conda,
                 storage_dir, experiment_id, block):
    """Locally run a project that has been checked out in `work_dir`."""
    storage_dir_for_run = _get_storage_dir(storage_dir)
    eprint(
        "=== Created directory %s for downloading remote URIs passed to arguments of "
        "type 'path' ===" % storage_dir_for_run)
    # Try to build the command first in case the user mis-specified parameters
    run_project_command = project.get_entry_point(entry_point).compute_command(
        parameters, storage_dir_for_run)
    commands = []
    if use_conda:
        conda_env_path = os.path.abspath(
            os.path.join(work_dir, project.conda_env))
        _maybe_create_conda_env(conda_env_path)
        commands.append("source activate %s" %
                        _get_conda_env_name(conda_env_path))

    # Create a new run and log every provided parameter into it.
    active_run = tracking._create_run(
        experiment_id=experiment_id,
        source_name=project.uri,
        source_version=tracking._get_git_commit(work_dir),
        entry_point_name=entry_point,
        source_type=SourceType.PROJECT)
    if parameters is not None:
        for key, value in parameters.items():
            active_run.log_param(Param(key, value))
    # Add the run id into a magic environment variable that the subprocess will read,
    # causing it to reuse the run.
    env_map = {
        tracking._RUN_ID_ENV_VAR: active_run.run_info.run_uuid,
        tracking._TRACKING_URI_ENV_VAR: tracking.get_tracking_uri(),
        tracking._EXPERIMENT_ID_ENV_VAR: str(experiment_id),
    }

    commands.append(run_project_command)
    command = " && ".join(commands)
    eprint("=== Running command '%s' in run with ID '%s' === " %
           (command, active_run.run_info.run_uuid))

    return _launch_local_run(active_run,
                             command,
                             work_dir,
                             env_map,
                             stream_output=block)
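
The joined command runs in a subprocess that inherits env_map, which is how the child process picks up the run ID and reuses the existing run. A minimal sketch of what a launcher like _launch_local_run might do (the body here is an assumption; only the env-merging idea comes from the code above):

import os
import subprocess

def launch_local_run(command, work_dir, env_map, stream_output=True):
    # Merge the run-specific variables (run ID, tracking URI, experiment ID)
    # into the current environment and execute the joined command in a shell.
    env = dict(os.environ)
    env.update(env_map)
    process = subprocess.Popen(["bash", "-c", command], cwd=work_dir, env=env)
    if stream_output:
        # Block until the entry-point command finishes, mirroring block=True.
        return process.wait()
    return process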
Example 5
def _create_databricks_run(tracking_uri, experiment_id, source_name,
                           source_version, entry_point_name):
    """
    Makes an API request to the specified tracking server to create a new run with the specified
    attributes. Returns an ``ActiveRun`` that can be used to query the tracking server for the run's
    status or log metrics/params for the run.
    """
    if tracking.is_local_uri(tracking_uri):
        # TODO: we'll actually use the Databricks deployment's tracking URI here in the future
        eprint(
            "WARNING: MLflow tracking URI is set to a local URI (%s), so results from "
            "Databricks will not be logged permanently." % tracking_uri)
        return None
    else:
        # Assume non-local tracking URIs are accessible from Databricks (won't work for e.g.
        # localhost)
        return tracking._create_run(experiment_id=experiment_id,
                                    source_name=source_name,
                                    source_version=source_version,
                                    entry_point_name=entry_point_name,
                                    source_type=SourceType.PROJECT)
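
Because this variant returns None for local tracking URIs, callers must handle both branches. A hedged usage sketch (the URI, experiment ID, and source values are hypothetical):

run = _create_databricks_run(
    tracking_uri="https://tracking.example.com:5000",   # hypothetical server
    experiment_id=0,
    source_name="https://github.com/example/project",   # hypothetical project
    source_version="abc1234",
    entry_point_name="main")
if run is None:
    # Local tracking URI: no server-side run was created, so there is
    # nothing to query or log against from the Databricks job.
    print("Skipping remote tracking; results will not be logged permanently.")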