Exemple #1
0
def _fetch_project(uri,
                   force_tempdir,
                   version=None,
                   git_username=None,
                   git_password=None):
    """
    Fetch a project into a local directory, returning the path to the local project directory.

    :param uri: Project URI. May carry a ``#``-delimited subdirectory suffix, which is split off
                by ``_parse_subdirectory`` and resolved relative to the fetched project.
    :param force_tempdir: If True, will fetch the project into a temporary directory. Otherwise,
                          will fetch Git projects into a temporary directory but simply return the
                          path of local projects (i.e. perform a no-op for local projects).
    :param version: Version to check out. Only supported for Git project URIs.
    :param git_username: Forwarded to ``_fetch_git_repo`` for non-local projects.
    :param git_password: Forwarded to ``_fetch_git_repo`` for non-local projects.
    :return: Absolute path of the project directory (including the subdirectory, if any).
    :raises ExecutionException: If ``version`` is given for a local project, or if the requested
                                subdirectory does not exist in the fetched project.
    """
    parsed_uri, subdirectory = _parse_subdirectory(uri)
    # Decide locality exactly once, on the URI with the '#subdirectory' suffix stripped, so that
    # every branch below agrees on whether this is a local or a Git project.
    is_local = _is_local_uri(parsed_uri)
    # Validate before touching the filesystem so a rejected call does not leave behind an
    # empty temporary directory.
    if is_local and version is not None:
        raise ExecutionException(
            "Setting a version is only supported for Git project URIs")
    use_temp_dst_dir = force_tempdir or not is_local
    dst_dir = tempfile.mkdtemp() if use_temp_dst_dir else parsed_uri
    if use_temp_dst_dir:
        eprint("=== Fetching project from %s into %s ===" % (uri, dst_dir))
    if is_local:
        # Local projects are only copied when a temporary directory was requested; otherwise
        # dst_dir already is the project directory.
        if use_temp_dst_dir:
            dir_util.copy_tree(src=parsed_uri, dst=dst_dir)
    else:
        assert _GIT_URI_REGEX.match(
            parsed_uri), "Non-local URI %s should be a Git URI" % parsed_uri
        _fetch_git_repo(parsed_uri, version, dst_dir, git_username,
                        git_password)
    res = os.path.abspath(os.path.join(dst_dir, subdirectory))
    if not os.path.exists(res):
        raise ExecutionException("Could not find subdirectory %s of %s" %
                                 (subdirectory, dst_dir))
    return res
Exemple #2
0
def _before_run_validations(tracking_uri, cluster_spec):
    """
    Run the pre-flight checks required before launching a project on Databricks.

    The checks run in this order: Databricks auth is available, a cluster spec was supplied,
    and the tracking URI is not a local one.

    :param tracking_uri: Tracking URI the run will report to; must not be local.
    :param cluster_spec: Cluster specification for the run; must not be None.
    :raises ExecutionException: If the cluster spec is missing or the tracking URI is local.
    """
    _check_databricks_auth_available()
    if cluster_spec is None:
        raise ExecutionException(
            "Cluster spec must be provided when launching MLflow project runs on Databricks.")
    tracking_uri_is_local = tracking.utils._is_local_uri(tracking_uri)
    if not tracking_uri_is_local:
        return
    message = (
        "When running on Databricks, the MLflow tracking URI must be set to a remote URI "
        "accessible to both the current client and code running on Databricks. Got local "
        "tracking URI %s." % tracking_uri)
    raise ExecutionException(message)
Exemple #3
0
 def _validate_parameters(self, user_parameters):
     missing_params = []
     for name in self.parameters:
         if (name not in user_parameters
                 and self.parameters[name].default is None):
             missing_params.append(name)
     if len(missing_params) == 1:
         raise ExecutionException(
             "No value given for missing parameter: '%s'" %
             missing_params[0])
     elif len(missing_params) > 1:
         raise ExecutionException(
             "No value given for missing parameters: %s" %
             ", ".join(["'%s'" % name for name in missing_params]))
Exemple #4
0
def _get_or_create_conda_env(conda_env_path):
    """
    Make sure the conda environment associated with ``conda_env_path`` exists, creating it if it
    is not among the environments conda already lists.

    The environment name is obtained from ``_get_conda_env_name(conda_env_path)``. When
    ``conda_env_path`` is truthy the environment is created from that file; otherwise an
    environment containing only ``python`` is created.

    :param conda_env_path: Path to a conda environment file, or a falsy value.
    :return: The name of the conda environment.
    :raises ExecutionException: If probing the conda executable raises ``EnvironmentError``.
    """
    conda_path = _get_conda_bin_executable("conda")
    # Probe the executable first; an EnvironmentError here is reported as conda being
    # unavailable at the resolved path.
    try:
        process.exec_cmd([conda_path, "--help"], throw_on_error=False)
    except EnvironmentError:
        raise ExecutionException(
            "Could not find Conda executable at {0}. "
            "Ensure Conda is installed as per the instructions "
            "at https://conda.io/docs/user-guide/install/index.html. You can "
            "also configure MLflow to look for a specific Conda executable "
            "by setting the {1} environment variable to the path of the Conda "
            "executable".format(conda_path, MLFLOW_CONDA_HOME))
    (_, listing_json, _) = process.exec_cmd([conda_path, "env", "list", "--json"])
    existing_env_names = [
        os.path.basename(env_dir) for env_dir in json.loads(listing_json)['envs']
    ]
    project_env_name = _get_conda_env_name(conda_env_path)
    if project_env_name in existing_env_names:
        return project_env_name

    eprint('=== Creating conda environment %s ===' % project_env_name)
    if conda_env_path:
        create_cmd = [
            conda_path, "env", "create", "-n", project_env_name, "--file",
            conda_env_path
        ]
    else:
        create_cmd = [conda_path, "create", "-n", project_env_name, "python"]
    process.exec_cmd(create_cmd, stream_output=True)
    return project_env_name
Exemple #5
0
def _fetch_git_repo(uri, version, dst_dir, git_username, git_password):
    """
    Clone the git repo at ``uri`` into ``dst_dir``, checking out commit ``version`` (or defaulting
    to the head commit of the repository's master branch if version is unspecified).
    If ``git_username`` and ``git_password`` are specified, uses them to authenticate while fetching
    the repo. Otherwise, assumes authentication parameters are specified by the environment,
    e.g. by a Git credential helper.

    :param uri: URI of the Git repository to fetch.
    :param version: Version to check out, or None to check out ``master``.
    :param dst_dir: Directory in which the repository is initialized.
    :param git_username: Username for authentication, or None.
    :param git_password: Password for authentication, or None.
    :raises ExecutionException: If exactly one of ``git_username`` / ``git_password`` is None.
    """
    # Validate the credential arguments up front, before anything is written to dst_dir, so an
    # invalid call does not leave a half-initialized repository behind.
    if (git_username is None) != (git_password is None):
        raise ExecutionException(
            "Either both or neither of git_username and git_password must be "
            "specified.")
    # We defer importing git until the last moment, because the import requires that the git
    # executable is available on the PATH, so we only want to fail if we actually need it.
    import git
    repo = git.Repo.init(dst_dir)
    origin = repo.create_remote("origin", uri)
    if git_username:
        # Hand the credentials to git's credential cache, scoped to this repository only.
        git_credentials = "url=%s\nusername=%s\npassword=%s" % (
            uri, git_username, git_password)
        repo.git.config("--local", "credential.helper", "cache")
        process.exec_cmd(cmd=["git", "credential-cache", "store"],
                         cwd=dst_dir,
                         cmd_stdin=git_credentials)
    origin.fetch()
    if version is not None:
        repo.git.checkout(version)
    else:
        repo.create_head("master", origin.refs.master)
        repo.heads.master.checkout()
Exemple #6
0
def load_project(directory):
    """
    Build a ``Project`` from the contents of ``directory``.

    Entry points and the conda environment file are read from the MLproject file when it
    exists. If the MLproject file names no conda environment, the default conda file in
    ``directory`` is used when present; otherwise the project has no conda environment.

    :param directory: Path of the project directory.
    :return: A ``Project`` holding the resolved conda environment path (or None) and the
             entry points keyed by name.
    :raises ExecutionException: If the MLproject file names a conda environment file that does
                                not exist.
    """
    mlproject_path = os.path.join(directory, MLPROJECT_FILE_NAME)
    # TODO: Validate structure of YAML loaded from the file
    yaml_obj = {}
    if os.path.exists(mlproject_path):
        with open(mlproject_path) as mlproject_file:
            # safe_load returns None for an empty document; treat that like a missing file
            # instead of failing on `.get` below.
            yaml_obj = yaml.safe_load(mlproject_file.read()) or {}
    entry_points = {}
    # `entry_points:` with no value parses to None, so fall back to an empty mapping.
    for name, entry_point_yaml in (yaml_obj.get("entry_points") or {}).items():
        parameters = entry_point_yaml.get("parameters", {})
        command = entry_point_yaml.get("command")
        entry_points[name] = EntryPoint(name, parameters, command)
    conda_path = yaml_obj.get("conda_env")
    if conda_path:
        # An explicitly named conda file must exist; it is resolved relative to the project.
        conda_env_path = os.path.join(directory, conda_path)
        if not os.path.exists(conda_env_path):
            raise ExecutionException(
                "Project specified conda environment file %s, but no such "
                "file was found." % conda_env_path)
        return Project(conda_env_path=conda_env_path,
                       entry_points=entry_points)
    default_conda_path = os.path.join(directory, DEFAULT_CONDA_FILE_NAME)
    if os.path.exists(default_conda_path):
        return Project(conda_env_path=default_conda_path,
                       entry_points=entry_points)
    return Project(conda_env_path=None, entry_points=entry_points)
Exemple #7
0
def _check_databricks_auth_available():
    """
    Verify that the Databricks CLI can be invoked and that Databricks auth can be obtained.

    :raises ExecutionException: If running ``databricks --version`` raises
                                ``process.ShellCommandException``.
    """
    version_cmd = ["databricks", "--version"]
    try:
        process.exec_cmd(version_cmd)
    except process.ShellCommandException:
        raise ExecutionException(
            "Could not find Databricks CLI on PATH. Please install and configure the Databricks "
            "CLI as described in https://github.com/databricks/databricks-cli")
    # The CLI is runnable; now make sure auth can actually be resolved (the helper is
    # expected to fail on its own when it cannot).
    rest_utils.get_databricks_http_request_kwargs_or_fail()
Exemple #8
0
def _parse_subdirectory(uri):
    # Parses a uri and returns the uri and subdirectory as separate values.
    # Uses '#' as a delimiter.
    subdirectory = ''
    parsed_uri = uri
    if '#' in uri:
        subdirectory = uri[uri.find('#')+1:]
        parsed_uri = uri[:uri.find('#')]
    if subdirectory and '.' in subdirectory:
        raise ExecutionException("'.' is not allowed in project subdirectory paths.")
    return parsed_uri, subdirectory
Exemple #9
0
 def _compute_path_value(self, user_param_value, storage_dir):
     """
     Resolve a path-typed parameter value to an absolute local path.

     A value that ``data.is_uri`` recognizes is downloaded into ``storage_dir`` under its
     basename (unless the destination equals the value itself); any other value must already
     exist on the local filesystem.

     :param user_param_value: User-supplied path or URI.
     :param storage_dir: Directory that receives downloaded data.
     :return: Absolute local path for the parameter.
     :raises ExecutionException: If a non-URI value does not exist locally.
     """
     if data.is_uri(user_param_value):
         dest_path = os.path.join(storage_dir, os.path.basename(user_param_value))
         # Skip the download when the value already points at the destination.
         if dest_path != user_param_value:
             data.download_uri(uri=user_param_value, output_path=dest_path)
         return os.path.abspath(dest_path)
     if os.path.exists(user_param_value):
         return os.path.abspath(user_param_value)
     raise ExecutionException(
         "Got value %s for parameter %s, but no such file or "
         "directory was found." % (user_param_value, self.name))
Exemple #10
0
 def get_entry_point(self, entry_point):
     """
     Return the entry point registered under ``entry_point``, or synthesize one for a script.

     When the name is not a registered entry point but ends in ``.py`` or ``.sh``, an
     ``EntryPoint`` with no parameters is built whose command runs the script with ``python``
     or with the shell named by the ``SHELL`` environment variable (``bash`` if unset).

     :param entry_point: Entry point name or script path.
     :raises ExecutionException: If the name is neither registered nor a supported script.
     """
     registered = self._entry_points
     if entry_point in registered:
         return registered[entry_point]
     interpreters = {".py": "python", ".sh": os.environ.get("SHELL", "bash")}
     extension = os.path.splitext(entry_point)[1]
     if extension not in interpreters:
         raise ExecutionException(
             "Could not find {0} among entry points {1} or interpret {0} as a "
             "runnable script. Supported script file extensions: "
             "{2}".format(entry_point, list(registered.keys()),
                          list(interpreters.keys())))
     command = "%s %s" % (interpreters[extension], shlex_quote(entry_point))
     # Exact-type check (not isinstance) is kept deliberately to match existing behavior.
     if type(command) not in six.string_types:
         command = command.encode("utf-8")
     return EntryPoint(name=entry_point, parameters={}, command=command)
Exemple #11
0
def _dbfs_path_exists(dbfs_uri):
    """
    Returns True if the passed-in path exists in DBFS for the workspace corresponding to the
    default Databricks CLI profile.

    :param dbfs_uri: DBFS URI whose path component is queried via ``dbfs/get-status``.
    :return: False if the API reports ``RESOURCE_DOES_NOT_EXIST``, True if the response
             carries no error code.
    :raises ExecutionException: If the response carries any other error code.
    """
    dbfs_path = _parse_dbfs_uri_path(dbfs_uri)
    json_response_obj = rest_utils.databricks_api_request(
        endpoint="dbfs/get-status", method="GET", json={"path": dbfs_path})
    # If request fails with a RESOURCE_DOES_NOT_EXIST error, the file does not exist on DBFS
    error_code_field = "error_code"
    if error_code_field in json_response_obj:
        if json_response_obj[error_code_field] == "RESOURCE_DOES_NOT_EXIST":
            return False
        # Both placeholders need a value: the URI that was checked and the raw response.
        # Passing the response dict alone made the formatting itself raise TypeError.
        raise ExecutionException(
            "Got unexpected error response when checking whether file %s "
            "exists in DBFS: %s" % (dbfs_uri, json_response_obj))
    return True
Exemple #12
0
def _wait_for(submitted_run_obj):
    """
    Block until the submitted run finishes and record its final status.

    The status is reported as "FINISHED" when ``wait()`` returns a truthy value and as "FAILED"
    otherwise. On ``KeyboardInterrupt`` the run is cancelled, marked "FAILED", and the interrupt
    is re-raised.

    :param submitted_run_obj: Submitted run exposing ``run_id``, ``wait()`` and ``cancel()``.
    :raises ExecutionException: If the run did not succeed.
    """
    run_id = submitted_run_obj.run_id
    active_run = None
    # Note: an interrupt arriving before the try block below is entered is not handled here, so
    # in that small window the run's status may go unreported.
    try:
        if run_id is not None:
            active_run = tracking.get_service().get_run(run_id)
        succeeded = submitted_run_obj.wait()
        if not succeeded:
            _maybe_set_run_terminated(active_run, "FAILED")
            raise ExecutionException("=== Run (ID '%s') failed ===" % run_id)
        eprint("=== Run (ID '%s') succeeded ===" % run_id)
        _maybe_set_run_terminated(active_run, "FINISHED")
    except KeyboardInterrupt:
        eprint("=== Run (ID '%s') === interrupted, cancelling run ===" % run_id)
        submitted_run_obj.cancel()
        _maybe_set_run_terminated(active_run, "FAILED")
        raise
Exemple #13
0
def _run(uri, entry_point="main", version=None, parameters=None, experiment_id=None,
         mode=None, cluster_spec=None, git_username=None, git_password=None, use_conda=True,
         storage_dir=None, block=True, run_id=None):
    """
    Helper that delegates to the project-running method corresponding to the passed-in mode.
    Returns a ``SubmittedRun`` corresponding to the project run.

    :param uri: URI of the project to run.
    :param entry_point: Name of the entry point to run within the project.
    :param version: Project version, forwarded to ``_fetch_project``.
    :param parameters: Dict of entry point parameters; None is treated as an empty dict.
    :param experiment_id: Experiment to run under; a falsy value falls back to
                          ``_get_experiment_id()``.
    :param mode: Execution mode: "local", "databricks", or None (treated as local).
    :param cluster_spec: Forwarded to ``run_databricks`` in "databricks" mode.
    :param git_username: Forwarded to ``_fetch_project``.
    :param git_password: Forwarded to ``_fetch_project``.
    :param use_conda: If True, get or create the project's conda environment for local runs.
    :param storage_dir: Forwarded to the local run helpers.
    :param block: For local runs, run the entry point in blocking fashion if True; otherwise
                  invoke ``mlflow run`` in a subprocess.
    :param run_id: If truthy, the existing run to use instead of creating a new one.
    :raises ExecutionException: If ``mode`` is not a supported execution mode.
    """
    # Reject an unsupported mode before doing any work, so that no project is fetched and no
    # tracking run is created for a call that can never be executed.
    supported_modes = ["local", "databricks"]
    if mode is not None and mode not in supported_modes:
        raise ExecutionException("Got unsupported execution mode %s. Supported "
                                 "values: %s" % (mode, supported_modes))

    exp_id = experiment_id or _get_experiment_id()
    parameters = parameters or {}
    work_dir = _fetch_project(uri=uri, force_tempdir=False, version=version,
                              git_username=git_username, git_password=git_password)
    project = _project_spec.load_project(work_dir)
    project.get_entry_point(entry_point)._validate_parameters(parameters)
    if run_id:
        active_run = tracking.get_service().get_run(run_id)
    else:
        active_run = _create_run(uri, exp_id, work_dir, entry_point, parameters)

    if mode == "databricks":
        from mlflow.projects.databricks import run_databricks
        return run_databricks(
            remote_run=active_run,
            uri=uri, entry_point=entry_point, work_dir=work_dir, parameters=parameters,
            experiment_id=exp_id, cluster_spec=cluster_spec)

    # Remaining modes are "local" and None, both of which run locally.
    # Synchronously create a conda environment (even though this may take some time) to avoid
    # failures due to multiple concurrent attempts to create the same conda env.
    conda_env_name = _get_or_create_conda_env(project.conda_env_path) if use_conda else None
    # In blocking mode, run the entry point command in blocking fashion, sending status updates
    # to the tracking server when finished. Note that the run state may not be persisted to the
    # tracking server if interrupted
    if block:
        command = _get_entry_point_command(
            project, entry_point, parameters, conda_env_name, storage_dir)
        return _run_entry_point(command, work_dir, exp_id, run_id=active_run.info.run_uuid)
    # Otherwise, invoke `mlflow run` in a subprocess
    return _invoke_mlflow_run_subprocess(
        work_dir=work_dir, entry_point=entry_point, parameters=parameters, experiment_id=exp_id,
        use_conda=use_conda, storage_dir=storage_dir, run_id=active_run.info.run_uuid)
Exemple #14
0
def test_execution_exception_string_repr():
    """The string form of an ExecutionException is exactly the message it was built with."""
    message = "Uh oh"
    assert str(ExecutionException(message)) == message
Exemple #15
0
 def _compute_uri_value(self, user_param_value):
     """
     Return ``user_param_value`` unchanged after checking that it is a URI.

     :param user_param_value: User-supplied value for this parameter.
     :raises ExecutionException: If ``data.is_uri`` does not recognize the value.
     """
     if data.is_uri(user_param_value):
         return user_param_value
     raise ExecutionException("Expected URI for parameter %s but got "
                              "%s" % (self.name, user_param_value))