# Shared stdlib/third-party imports for the excerpted functions below. MLflow-internal
# helpers (eprint, process, data, tracking, rest_utils, ExecutionException, and the
# _-prefixed constants/regexes) are assumed to be imported from the surrounding MLflow
# modules, as in the original sources.
import json
import os
import tempfile
from distutils import dir_util

import six
import yaml
from six.moves import shlex_quote


def _fetch_project(uri, force_tempdir, version=None, git_username=None, git_password=None):
    """
    Fetch a project into a local directory, returning the path to the local project directory.

    :param force_tempdir: If True, will fetch the project into a temporary directory. Otherwise,
                          will fetch Git projects into a temporary directory but simply return the
                          path of local projects (i.e. perform a no-op for local projects).
    """
    parsed_uri, subdirectory = _parse_subdirectory(uri)
    use_temp_dst_dir = force_tempdir or not _is_local_uri(parsed_uri)
    dst_dir = tempfile.mkdtemp() if use_temp_dst_dir else parsed_uri
    if use_temp_dst_dir:
        eprint("=== Fetching project from %s into %s ===" % (uri, dst_dir))
    if _is_local_uri(uri):
        if version is not None:
            raise ExecutionException("Setting a version is only supported for Git project URIs")
        if use_temp_dst_dir:
            dir_util.copy_tree(src=parsed_uri, dst=dst_dir)
    else:
        assert _GIT_URI_REGEX.match(parsed_uri), \
            "Non-local URI %s should be a Git URI" % parsed_uri
        _fetch_git_repo(parsed_uri, version, dst_dir, git_username, git_password)
    res = os.path.abspath(os.path.join(dst_dir, subdirectory))
    if not os.path.exists(res):
        raise ExecutionException("Could not find subdirectory %s of %s" % (subdirectory, dst_dir))
    return res
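# Illustrative usage sketch (not from the original module): fetching a Git project with a
# subdirectory suffix into a temp dir. The repository URI and subdirectory are hypothetical;
# version=None defaults to the head of the repository's master branch.
def _example_fetch_project():
    work_dir = _fetch_project(
        uri="https://github.com/mlflow/mlflow-example#sklearn_elasticnet_wine",
        force_tempdir=True, version=None)
    print("Fetched project into %s" % work_dir)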
def _before_run_validations(tracking_uri, cluster_spec):
    """Validations to perform before running a project on Databricks."""
    _check_databricks_auth_available()
    if cluster_spec is None:
        raise ExecutionException(
            "Cluster spec must be provided when launching MLflow project runs on Databricks.")
    if tracking.utils._is_local_uri(tracking_uri):
        raise ExecutionException(
            "When running on Databricks, the MLflow tracking URI must be set to a remote URI "
            "accessible to both the current client and code running on Databricks. Got local "
            "tracking URI %s." % tracking_uri)
def _validate_parameters(self, user_parameters):
    missing_params = []
    for name in self.parameters:
        if name not in user_parameters and self.parameters[name].default is None:
            missing_params.append(name)
    if len(missing_params) == 1:
        raise ExecutionException(
            "No value given for missing parameter: '%s'" % missing_params[0])
    elif len(missing_params) > 1:
        raise ExecutionException(
            "No value given for missing parameters: %s" %
            ", ".join(["'%s'" % name for name in missing_params]))
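# Minimal sketch of how validation behaves, assuming EntryPoint wraps the YAML parameter
# specs (as parsed from an MLproject file) in Parameter objects whose ``default`` is None
# when unspecified. The entry point name, parameter name, and command are hypothetical.
def _example_validate_parameters():
    entry_point = EntryPoint(
        name="main", parameters={"alpha": {"type": "float"}},
        command="python train.py --alpha {alpha}")
    try:
        entry_point._validate_parameters(user_parameters={})
    except ExecutionException as e:
        print("Expected failure: %s" % e)  # no value given for 'alpha'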
def _get_or_create_conda_env(conda_env_path):
    """
    Given the path to a project's conda environment file, creates a conda environment containing
    the project's dependencies if such an environment doesn't already exist. Returns the name of
    the conda environment.
    """
    conda_path = _get_conda_bin_executable("conda")
    try:
        process.exec_cmd([conda_path, "--help"], throw_on_error=False)
    except EnvironmentError:
        raise ExecutionException(
            "Could not find Conda executable at {0}. "
            "Ensure Conda is installed as per the instructions "
            "at https://conda.io/docs/user-guide/install/index.html. You can "
            "also configure MLflow to look for a specific Conda executable "
            "by setting the {1} environment variable to the path of the Conda "
            "executable".format(conda_path, MLFLOW_CONDA_HOME))
    (_, stdout, _) = process.exec_cmd([conda_path, "env", "list", "--json"])
    env_names = [os.path.basename(env) for env in json.loads(stdout)['envs']]
    project_env_name = _get_conda_env_name(conda_env_path)
    if project_env_name not in env_names:
        eprint('=== Creating conda environment %s ===' % project_env_name)
        if conda_env_path:
            process.exec_cmd(
                [conda_path, "env", "create", "-n", project_env_name, "--file", conda_env_path],
                stream_output=True)
        else:
            process.exec_cmd(
                [conda_path, "create", "-n", project_env_name, "python"], stream_output=True)
    return project_env_name
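# Usage sketch (hypothetical path): resolve or create the conda environment for a project's
# environment file before constructing the entry point command.
def _example_get_or_create_conda_env():
    env_name = _get_or_create_conda_env("/path/to/project/conda.yaml")
    print("Running in conda environment %s" % env_name)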
def _fetch_git_repo(uri, version, dst_dir, git_username, git_password):
    """
    Clone the git repo at ``uri`` into ``dst_dir``, checking out commit ``version`` (or defaulting
    to the head commit of the repository's master branch if version is unspecified). If
    ``git_username`` and ``git_password`` are specified, uses them to authenticate while fetching
    the repo. Otherwise, assumes authentication parameters are specified by the environment,
    e.g. by a Git credential helper.
    """
    # We defer importing git until the last moment, because the import requires that the git
    # executable is available on the PATH, so we only want to fail if we actually need it.
    import git
    repo = git.Repo.init(dst_dir)
    origin = repo.create_remote("origin", uri)
    git_args = [git_username, git_password]
    if not (all(arg is not None for arg in git_args) or all(arg is None for arg in git_args)):
        raise ExecutionException(
            "Either both or neither of git_username and git_password must be specified.")
    if git_username:
        git_credentials = "url=%s\nusername=%s\npassword=%s" % (uri, git_username, git_password)
        repo.git.config("--local", "credential.helper", "cache")
        process.exec_cmd(
            cmd=["git", "credential-cache", "store"], cwd=dst_dir, cmd_stdin=git_credentials)
    origin.fetch()
    if version is not None:
        repo.git.checkout(version)
    else:
        repo.create_head("master", origin.refs.master)
        repo.heads.master.checkout()
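# Usage sketch: an anonymous clone of a public repository. The URI is a placeholder;
# passing None for username/password defers authentication to the environment, and
# version=None checks out the head of master as described in the docstring above.
def _example_fetch_git_repo():
    dst_dir = tempfile.mkdtemp()
    _fetch_git_repo(
        uri="https://github.com/mlflow/mlflow-example.git", version=None,
        dst_dir=dst_dir, git_username=None, git_password=None)
    print("Cloned into %s" % dst_dir)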
def load_project(directory):
    mlproject_path = os.path.join(directory, MLPROJECT_FILE_NAME)
    # TODO: Validate structure of YAML loaded from the file
    if os.path.exists(mlproject_path):
        with open(mlproject_path) as mlproject_file:
            yaml_obj = yaml.safe_load(mlproject_file.read())
    else:
        yaml_obj = {}
    entry_points = {}
    for name, entry_point_yaml in yaml_obj.get("entry_points", {}).items():
        parameters = entry_point_yaml.get("parameters", {})
        command = entry_point_yaml.get("command")
        entry_points[name] = EntryPoint(name, parameters, command)
    conda_path = yaml_obj.get("conda_env")
    if conda_path:
        conda_env_path = os.path.join(directory, conda_path)
        if not os.path.exists(conda_env_path):
            raise ExecutionException(
                "Project specified conda environment file %s, but no such "
                "file was found." % conda_env_path)
        return Project(conda_env_path=conda_env_path, entry_points=entry_points)
    default_conda_path = os.path.join(directory, DEFAULT_CONDA_FILE_NAME)
    if os.path.exists(default_conda_path):
        return Project(conda_env_path=default_conda_path, entry_points=entry_points)
    return Project(conda_env_path=None, entry_points=entry_points)
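# Usage sketch (hypothetical directory): load a project spec and inspect its declared entry
# points. Accessing the internal ``_entry_points`` dict here is for illustration only.
def _example_load_project():
    project = load_project("/path/to/project")
    print("Entry points: %s" % sorted(project._entry_points))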
def _check_databricks_auth_available():
    try:
        process.exec_cmd(["databricks", "--version"])
    except process.ShellCommandException:
        raise ExecutionException(
            "Could not find Databricks CLI on PATH. Please install and configure the Databricks "
            "CLI as described in https://github.com/databricks/databricks-cli")
    # Verify that we can get Databricks auth
    rest_utils.get_databricks_http_request_kwargs_or_fail()
def _parse_subdirectory(uri):
    # Parses a uri and returns the uri and subdirectory as separate values.
    # Uses '#' as a delimiter.
    subdirectory = ''
    parsed_uri = uri
    if '#' in uri:
        subdirectory = uri[uri.find('#') + 1:]
        parsed_uri = uri[:uri.find('#')]
    if subdirectory and '.' in subdirectory:
        raise ExecutionException("'.' is not allowed in project subdirectory paths.")
    return parsed_uri, subdirectory
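# Worked example of the '#' delimiter handling above; the URI is a placeholder.
def _example_parse_subdirectory():
    parsed_uri, subdirectory = _parse_subdirectory(
        "https://github.com/mlflow/mlflow-example#sklearn_elasticnet_wine")
    assert parsed_uri == "https://github.com/mlflow/mlflow-example"
    assert subdirectory == "sklearn_elasticnet_wine"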
def _compute_path_value(self, user_param_value, storage_dir):
    if not data.is_uri(user_param_value):
        if not os.path.exists(user_param_value):
            raise ExecutionException(
                "Got value %s for parameter %s, but no such file or "
                "directory was found." % (user_param_value, self.name))
        return os.path.abspath(user_param_value)
    basename = os.path.basename(user_param_value)
    dest_path = os.path.join(storage_dir, basename)
    if dest_path != user_param_value:
        data.download_uri(uri=user_param_value, output_path=dest_path)
    return os.path.abspath(dest_path)
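# Behavior sketch for path-typed parameters: a non-URI value must already exist locally and
# is returned as an absolute path; a URI value is first downloaded into ``storage_dir``.
# ``param`` stands for a parameter instance; the file paths below are illustrative.
def _example_compute_path_value(param, storage_dir):
    local = param._compute_path_value("data/train.csv", storage_dir)  # existing local file
    remote = param._compute_path_value("s3://bucket/train.csv", storage_dir)  # downloaded copy
    return local, remote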
def get_entry_point(self, entry_point):
    if entry_point in self._entry_points:
        return self._entry_points[entry_point]
    _, file_extension = os.path.splitext(entry_point)
    ext_to_cmd = {".py": "python", ".sh": os.environ.get("SHELL", "bash")}
    if file_extension in ext_to_cmd:
        command = "%s %s" % (ext_to_cmd[file_extension], shlex_quote(entry_point))
        if not isinstance(command, six.string_types):
            command = command.encode("utf-8")
        return EntryPoint(name=entry_point, parameters={}, command=command)
    raise ExecutionException(
        "Could not find {0} among entry points {1} or interpret {0} as a "
        "runnable script. Supported script file extensions: "
        "{2}".format(entry_point, list(self._entry_points.keys()), list(ext_to_cmd.keys())))
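# Fallback sketch: an undeclared '.py' entry point is synthesized into a 'python <script>'
# command. ``project`` stands for a loaded Project; the script name is hypothetical.
def _example_get_entry_point(project):
    entry_point = project.get_entry_point("train.py")
    assert entry_point.command == "python train.py"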
def _dbfs_path_exists(dbfs_uri):
    """
    Returns True if the passed-in path exists in DBFS for the workspace corresponding to the
    default Databricks CLI profile.
    """
    dbfs_path = _parse_dbfs_uri_path(dbfs_uri)
    json_response_obj = rest_utils.databricks_api_request(
        endpoint="dbfs/get-status", method="GET", json={"path": dbfs_path})
    # If request fails with a RESOURCE_DOES_NOT_EXIST error, the file does not exist on DBFS
    error_code_field = "error_code"
    if error_code_field in json_response_obj:
        if json_response_obj[error_code_field] == "RESOURCE_DOES_NOT_EXIST":
            return False
        raise ExecutionException(
            "Got unexpected error response when checking whether file %s "
            "exists in DBFS: %s" % (dbfs_uri, json_response_obj))
    return True
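# Usage sketch (hypothetical DBFS URI): requires a configured default Databricks CLI profile.
def _example_dbfs_path_exists():
    if not _dbfs_path_exists("dbfs:/mlflow-projects/some-project.tar.gz"):
        print("Project archive not found in DBFS; upload needed")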
def _wait_for(submitted_run_obj):
    """Wait on the passed-in submitted run, reporting its status to the tracking server."""
    run_id = submitted_run_obj.run_id
    active_run = None
    # Note: there's a small chance we fail to report the run's status to the tracking server if
    # we're interrupted before we reach the try block below
    try:
        active_run = tracking.get_service().get_run(run_id) if run_id is not None else None
        if submitted_run_obj.wait():
            eprint("=== Run (ID '%s') succeeded ===" % run_id)
            _maybe_set_run_terminated(active_run, "FINISHED")
        else:
            _maybe_set_run_terminated(active_run, "FAILED")
            raise ExecutionException("Run (ID '%s') failed" % run_id)
    except KeyboardInterrupt:
        eprint("=== Run (ID '%s') interrupted, cancelling run ===" % run_id)
        submitted_run_obj.cancel()
        _maybe_set_run_terminated(active_run, "FAILED")
        raise
def _run(uri, entry_point="main", version=None, parameters=None, experiment_id=None, mode=None,
         cluster_spec=None, git_username=None, git_password=None, use_conda=True,
         storage_dir=None, block=True, run_id=None):
    """
    Helper that delegates to the project-running method corresponding to the passed-in mode.
    Returns a ``SubmittedRun`` corresponding to the project run.
    """
    exp_id = experiment_id or _get_experiment_id()
    parameters = parameters or {}
    work_dir = _fetch_project(uri=uri, force_tempdir=False, version=version,
                              git_username=git_username, git_password=git_password)
    project = _project_spec.load_project(work_dir)
    project.get_entry_point(entry_point)._validate_parameters(parameters)
    if run_id:
        active_run = tracking.get_service().get_run(run_id)
    else:
        active_run = _create_run(uri, exp_id, work_dir, entry_point, parameters)
    if mode == "databricks":
        from mlflow.projects.databricks import run_databricks
        return run_databricks(
            remote_run=active_run, uri=uri, entry_point=entry_point, work_dir=work_dir,
            parameters=parameters, experiment_id=exp_id, cluster_spec=cluster_spec)
    elif mode == "local" or mode is None:
        # Synchronously create a conda environment (even though this may take some time) to avoid
        # failures due to multiple concurrent attempts to create the same conda env.
        conda_env_name = _get_or_create_conda_env(project.conda_env_path) if use_conda else None
        # In blocking mode, run the entry point command in blocking fashion, sending status
        # updates to the tracking server when finished. Note that the run state may not be
        # persisted to the tracking server if interrupted.
        if block:
            command = _get_entry_point_command(
                project, entry_point, parameters, conda_env_name, storage_dir)
            return _run_entry_point(command, work_dir, exp_id, run_id=active_run.info.run_uuid)
        # Otherwise, invoke `mlflow run` in a subprocess
        return _invoke_mlflow_run_subprocess(
            work_dir=work_dir, entry_point=entry_point, parameters=parameters,
            experiment_id=exp_id, use_conda=use_conda, storage_dir=storage_dir,
            run_id=active_run.info.run_uuid)
    supported_modes = ["local", "databricks"]
    raise ExecutionException("Got unsupported execution mode %s. Supported "
                             "values: %s" % (mode, supported_modes))
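# Usage sketch: a blocking local run. The URI and parameter values are placeholders;
# callers normally reach _run via the public mlflow.projects.run API rather than directly.
def _example_run_local():
    submitted_run = _run(
        uri="https://github.com/mlflow/mlflow-example", entry_point="main",
        parameters={"alpha": "0.5"}, mode="local", use_conda=True, block=True)
    print("Run ID: %s" % submitted_run.run_id)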
def test_execution_exception_string_repr():
    exc = ExecutionException("Uh oh")
    assert str(exc) == "Uh oh"
def _compute_uri_value(self, user_param_value):
    if not data.is_uri(user_param_value):
        raise ExecutionException(
            "Expected URI for parameter %s but got %s" % (self.name, user_param_value))
    return user_param_value