예제 #1
0
def launch(
    environment: str,
    job: str,
    trace: bool,
    kill_on_sigterm: bool,
    existing_runs: str,
    as_run_submit: bool,
    tags: List[str],
    parameters: List[str],
    parameters_raw: Optional[str],
):
    """Launch ``job`` in ``environment`` and record the outcome as MLflow tags.

    The deployment run is located via ``_find_deployment_run``, re-opened in
    MLflow, and a nested MLflow run is started. Inside it the job is launched
    through either ``RunNowLauncher`` or ``RunSubmitLauncher`` and, optionally,
    traced until ``_trace_run`` returns a status.

    Args:
        environment: Environment name; passed to ``prepare_environment``,
            ``generate_filter_string``, ``_find_deployment_run`` and (for the
            run-submit path) ``RunSubmitLauncher``.
        job: Job identifier handed to the launcher.
        trace: When true, block on ``_trace_run`` and report its status;
            otherwise the status is recorded as ``"NOT_TRACKED"``.
        kill_on_sigterm: Only consulted when ``trace`` is true. When true, a
            ``KeyboardInterrupt`` raised while tracing cancels the run via
            ``_cancel_run`` and the status becomes ``"CANCELLED"``.
        existing_runs: Forwarded unchanged to the launcher.
        as_run_submit: Selects ``RunSubmitLauncher`` when true,
            ``RunNowLauncher`` otherwise.
        tags: Raw tag arguments; parsed with ``parse_multiple`` and forwarded
            to ``_find_deployment_run``.
        parameters: Raw parameter overrides; parsed with ``parse_multiple``
            and flattened to ``[key, value, key, value, ...]``. Ignored when
            ``parameters_raw`` is truthy.
        parameters_raw: If truthy, forwarded to the launcher as-is instead of
            the parsed ``parameters``.

    Raises:
        Exception: If tracing is enabled and the traced status is ``"ERROR"``.
    """
    dbx_echo(f"Launching job {job} on environment {environment}")

    api_client = prepare_environment(environment)
    additional_tags = parse_multiple(tags)

    if parameters_raw:
        # The raw payload takes precedence and is forwarded untouched.
        prepared_parameters = parameters_raw
    else:
        override_parameters = parse_multiple(parameters)
        # Flatten the mapping into [k1, v1, k2, v2, ...] in a single pass;
        # summing lists onto [] re-copies the accumulator on every step.
        prepared_parameters = [
            item for pair in override_parameters.items() for item in pair
        ]

    filter_string = generate_filter_string(environment)

    deployment_info = _find_deployment_run(filter_string, additional_tags,
                                           as_run_submit, environment)

    deployment_run_id = deployment_info["run_id"]

    with mlflow.start_run(run_id=deployment_run_id) as deployment_run:

        # The launch itself is tracked as a nested run, so the tags set at the
        # end of this block land on the nested run rather than the deployment.
        with mlflow.start_run(nested=True):
            artifact_base_uri = deployment_run.info.artifact_uri

            if not as_run_submit:
                run_launcher = RunNowLauncher(job, api_client,
                                              artifact_base_uri, existing_runs,
                                              prepared_parameters)
            else:
                run_launcher = RunSubmitLauncher(job, api_client,
                                                 artifact_base_uri,
                                                 existing_runs,
                                                 prepared_parameters,
                                                 environment)

            run_data, job_id = run_launcher.launch()

            jobs_service = JobsService(api_client)
            launched_run_info = jobs_service.get_run(run_data["run_id"])
            run_url = launched_run_info.get("run_page_url")
            dbx_echo(f"Run URL: {run_url}")
            if trace:
                if kill_on_sigterm:
                    dbx_echo("Click Ctrl+C to stop the run")
                    try:
                        dbx_status = _trace_run(api_client, run_data)
                    except KeyboardInterrupt:
                        # User interrupted the trace: cancel the remote run
                        # instead of leaving it running unattended.
                        dbx_status = "CANCELLED"
                        dbx_echo("Cancelling the run gracefully")
                        _cancel_run(api_client, run_data)
                        dbx_echo("Run cancelled successfully")
                else:
                    dbx_status = _trace_run(api_client, run_data)

                # NOTE(review): raising here skips mlflow.set_tags below, so an
                # "ERROR" status is never written as a tag. Confirm that this
                # is intended before changing it.
                if dbx_status == "ERROR":
                    raise Exception(
                        "Tracked run failed during execution. Please check Databricks UI for run logs"
                    )
                dbx_echo("Launch command finished")

            else:
                dbx_status = "NOT_TRACKED"
                dbx_echo(
                    "Run successfully launched in non-tracking mode. Please check Databricks UI for job status"
                )

            deployment_tags = {
                "job_id": job_id,
                "run_id": run_data.get("run_id"),
                "dbx_action_type": "launch",
                "dbx_status": dbx_status,
                "dbx_environment": environment,
            }

            mlflow.set_tags(deployment_tags)
예제 #2
0
def _get_run_status(api_client: ApiClient,
                    run_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return what ``JobsService.get_run`` reports for ``run_data["run_id"]``.

    Args:
        api_client: Client used to construct the ``JobsService``.
        run_data: Mapping that must contain a ``"run_id"`` key.

    Returns:
        The value returned by ``JobsService.get_run`` for that run id.

    Raises:
        KeyError: If ``run_data`` has no ``"run_id"`` entry.
    """
    return JobsService(api_client).get_run(run_data["run_id"])