class DatabricksPySparkStepLauncher(StepLauncher):
    """Step launcher that executes each step as a Databricks ``spark_python_task`` run.

    For every step it stages the entry-point script, a zip of the local job package, the pickled
    step run ref and the pickled ``DatabricksConfig`` on DBFS under
    ``<staging_prefix>/<run_id>/<sanitized step_key>/``, submits a run, and then polls an events
    file on DBFS that the remote process writes, replaying those records into the local instance.
    """

    def __init__(
        self,
        run_config,
        databricks_host,
        databricks_token,
        secrets_to_env_variables,
        storage,
        staging_prefix,
        wait_for_logs,
        max_completion_wait_time_seconds,
        poll_interval_sec=5,
        local_pipeline_package_path=None,
        local_dagster_job_package_path=None,
    ):
        """Validate configuration and create the underlying ``DatabricksJobRunner``.

        Exactly one of ``local_dagster_job_package_path`` and ``local_pipeline_package_path`` must
        be given. ``staging_prefix`` must be an absolute path.
        """
        self.run_config = check.dict_param(run_config, "run_config")
        self.databricks_host = check.str_param(databricks_host, "databricks_host")
        self.databricks_token = check.str_param(databricks_token, "databricks_token")
        self.secrets = check.list_param(secrets_to_env_variables, "secrets_to_env_variables", dict)
        self.storage = check.dict_param(storage, "storage")
        check.invariant(
            local_dagster_job_package_path is not None or local_pipeline_package_path is not None,
            "Missing config: need to provide either 'local_dagster_job_package_path' or 'local_pipeline_package_path' config entry",
        )
        check.invariant(
            local_dagster_job_package_path is None or local_pipeline_package_path is None,
            "Error in config: Provided both 'local_dagster_job_package_path' and 'local_pipeline_package_path' entries. Need to specify one or the other.",
        )
        # The two package-path options are mutually exclusive (checked above). Whichever one was
        # supplied is stored under a single attribute.
        self.local_dagster_job_package_path = check.str_param(
            local_pipeline_package_path or local_dagster_job_package_path,
            "local_dagster_job_package_path",
        )
        self.staging_prefix = check.str_param(staging_prefix, "staging_prefix")
        check.invariant(staging_prefix.startswith("/"), "staging_prefix must be an absolute path")
        self.wait_for_logs = check.bool_param(wait_for_logs, "wait_for_logs")

        self.databricks_runner = DatabricksJobRunner(
            host=databricks_host,
            token=databricks_token,
            poll_interval_sec=poll_interval_sec,
            max_wait_time_sec=max_completion_wait_time_seconds,
        )

    def launch_step(self, step_context, prior_attempts_count):
        """Run the step on Databricks, yielding its DagsterEvents while the run progresses.

        Compute logs (and cluster logs when ``wait_for_logs`` is set) are logged from a
        ``finally`` block, so they are reported even when the run fails or is interrupted.
        """
        step_run_ref = step_context_to_step_run_ref(
            step_context, prior_attempts_count, self.local_dagster_job_package_path
        )
        run_id = step_context.pipeline_run.run_id
        log = step_context.log

        step_key = step_run_ref.step_key
        self._upload_artifacts(log, step_run_ref, run_id, step_key)

        task = self._get_databricks_task(run_id, step_key)
        databricks_run_id = self.databricks_runner.submit_run(self.run_config, task)

        try:
            # If this is being called within a `capture_interrupts` context, allow interrupts while
            # waiting for the execution to complete, so that we can terminate slow or hanging steps
            with raise_execution_interrupts():
                yield from self.step_events_iterator(step_context, step_key, databricks_run_id)
        finally:
            self.log_compute_logs(log, run_id, step_key)
            # this is somewhat obsolete
            if self.wait_for_logs:
                self._log_logs_from_cluster(log, databricks_run_id)

    def log_compute_logs(self, log, run_id, step_key):
        """Log the ``stdout`` and ``stderr`` files from the step's DBFS staging directory.

        This is called from the ``finally`` block of ``launch_step``. It is therefore best-effort:
        if one stream cannot be read or decoded (for example because the remote process never
        created it), an error is logged rather than raised. Raising here would replace the
        exception that actually made the step fail.
        """
        for stream_name in ("stdout", "stderr"):
            try:
                raw = self.databricks_runner.client.read_file(
                    self._dbfs_path(run_id, step_key, stream_name)
                )
                # Undecodable bytes in user output should not hide the rest of the log.
                contents = raw.decode(errors="replace")
            except Exception as e:  # pylint: disable=broad-except
                log.error(f"Could not load captured {stream_name} for step {step_key}: {e}")
                continue
            log.info(f"Captured {stream_name} for step {step_key}:")
            log.info(contents)

    def step_events_iterator(self, step_context, step_key: str, databricks_run_id: int):
        """The launched Databricks job writes all event records to a specific dbfs file. This
        iterator regularly reads the contents of the file, adds any events that have not yet been
        seen to the instance, and yields any DagsterEvents.

        By doing this, we simulate having the remote Databricks process able to directly write to
        the local DagsterInstance. Importantly, this means that timestamps (and all other record
        properties) will be sourced from the Databricks process, rather than recording when this
        process happens to log them.
        """
        check.int_param(databricks_run_id, "databricks_run_id")
        processed_events = 0
        start = time.time()
        done = False
        step_context.log.info("Waiting for Databricks run %s to complete..." % databricks_run_id)
        while not done:
            with raise_execution_interrupts():
                step_context.log.debug(
                    "Waiting %.1f seconds...", self.databricks_runner.poll_interval_sec
                )
                time.sleep(self.databricks_runner.poll_interval_sec)
                try:
                    done = poll_run_state(
                        self.databricks_runner.client,
                        step_context.log,
                        start,
                        databricks_run_id,
                        self.databricks_runner.max_wait_time_sec,
                    )
                finally:
                    # Runs even when poll_run_state raises, so events written before a failure
                    # still reach the local instance.
                    all_events = self.get_step_events(step_context.run_id, step_key)
                    # we get all available records on each poll, but we only want to process the
                    # ones we haven't seen before
                    for event in all_events[processed_events:]:
                        # write each event from the DataBricks instance to the local instance
                        step_context.instance.handle_new_event(event)
                        if event.is_dagster_event:
                            yield event.dagster_event
                    processed_events = len(all_events)

        step_context.log.info(f"Databricks run {databricks_run_id} completed.")

    def get_step_events(self, run_id: str, step_key: str):
        """Return every event record the remote step has written to DBFS so far.

        Returns an empty list if the events file is empty, or if reading it fails with an
        ``HTTPError``. That includes the case where the file has not been created yet.
        """
        path = self._dbfs_path(run_id, step_key, PICKLED_EVENTS_FILE_NAME)

        def _get_step_records():
            serialized_records = self.databricks_runner.client.read_file(path)
            if not serialized_records:
                return []
            # NOTE(review): this unpickles bytes read back from DBFS, so it trusts anyone who can
            # write under the staging prefix.
            return deserialize_value(pickle.loads(serialized_records))

        try:
            # reading from dbfs while it writes can be flaky
            # allow for retry if we get malformed data
            return backoff(
                fn=_get_step_records,
                retry_on=(pickle.UnpicklingError,),
                max_retries=2,
            )
        except HTTPError:
            # If you poll before the Databricks process has had a chance to create the file, the
            # read fails with error_code RESOURCE_DOES_NOT_EXIST. That case, and any other HTTP
            # failure, is reported as "no events yet". The next poll re-reads the whole file.
            # The response body is deliberately not inspected: calling `.json()` here would itself
            # raise when the response is missing or not JSON.
            # NOTE(review): a failure on the final poll means trailing events are not replayed.
            return []

    def _get_databricks_task(self, run_id, step_key):
        """Construct the 'task' parameter to be submitted to the Databricks API.

        This will create a 'spark_python_task' dict where `python_file` is a path on DBFS
        pointing to the 'databricks_step_main.py' file, and `parameters` is a list of three
        `/dbfs/` paths: the pickled `step_run_ref`, the pickled Databricks config, and the
        zipped code package.

        See https://docs.databricks.com/dev-tools/api/latest/jobs.html#jobssparkpythontask.
        """
        python_file = self._dbfs_path(run_id, step_key, self._main_file_name())
        parameters = [
            self._internal_dbfs_path(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
            self._internal_dbfs_path(run_id, step_key, PICKLED_CONFIG_FILE_NAME),
            self._internal_dbfs_path(run_id, step_key, CODE_ZIP_NAME),
        ]
        return {"spark_python_task": {"python_file": python_file, "parameters": parameters}}

    def _upload_artifacts(self, log, step_run_ref, run_id, step_key):
        """Upload the step run ref and pyspark code to DBFS to run as a job."""

        log.info("Uploading main file to DBFS")
        main_local_path = self._main_file_local_path()
        with open(main_local_path, "rb") as infile:
            self.databricks_runner.client.put_file(
                infile, self._dbfs_path(run_id, step_key, self._main_file_name())
            )

        log.info("Uploading dagster job to DBFS")
        with tempfile.TemporaryDirectory() as temp_dir:
            # Zip and upload package containing dagster job
            zip_local_path = os.path.join(temp_dir, CODE_ZIP_NAME)
            build_pyspark_zip(zip_local_path, self.local_dagster_job_package_path)
            with open(zip_local_path, "rb") as infile:
                self.databricks_runner.client.put_file(
                    infile, self._dbfs_path(run_id, step_key, CODE_ZIP_NAME)
                )

        log.info("Uploading step run ref file to DBFS")
        step_pickle_file = io.BytesIO()
        pickle.dump(step_run_ref, step_pickle_file)
        step_pickle_file.seek(0)
        self.databricks_runner.client.put_file(
            step_pickle_file,
            self._dbfs_path(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
        )

        databricks_config = DatabricksConfig(
            storage=self.storage,
            secrets=self.secrets,
        )
        log.info("Uploading Databricks configuration to DBFS")
        databricks_config_file = io.BytesIO()
        pickle.dump(databricks_config, databricks_config_file)
        databricks_config_file.seek(0)
        self.databricks_runner.client.put_file(
            databricks_config_file,
            self._dbfs_path(run_id, step_key, PICKLED_CONFIG_FILE_NAME),
        )

    def _log_logs_from_cluster(self, log, run_id):
        """Log the stderr and then stdout the runner retrieves for ``run_id``, if any."""
        logs = self.databricks_runner.retrieve_logs_for_run_id(log, run_id)
        if logs is None:
            return
        stdout, stderr = logs
        if stderr:
            log.info(stderr)
        if stdout:
            log.info(stdout)

    def _main_file_name(self):
        return os.path.basename(self._main_file_local_path())

    def _main_file_local_path(self):
        return databricks_step_main.__file__

    def _sanitize_step_key(self, step_key: str) -> str:
        # step_keys of dynamic steps contain brackets, which are invalid characters
        return step_key.replace("[", "__").replace("]", "__")

    def _staging_path(self, run_id, step_key, filename):
        """Join staging_prefix/run_id/sanitized step_key/basename(filename) with '/'."""
        return "/".join(
            [
                self.staging_prefix,
                run_id,
                self._sanitize_step_key(step_key),
                os.path.basename(filename),
            ]
        )

    def _dbfs_path(self, run_id, step_key, filename):
        """DBFS URI for a staged file, as used by the client uploading and reading files."""
        return "dbfs://{}".format(self._staging_path(run_id, step_key, filename))

    def _internal_dbfs_path(self, run_id, step_key, filename):
        """Scripts running on Databricks should access DBFS at /dbfs/."""
        return "/dbfs/{}".format(self._staging_path(run_id, step_key, filename))
class DatabricksPySparkStepLauncher(StepLauncher):
    """Step launcher that executes each step as a Databricks ``spark_python_task`` run.

    The step's artifacts are staged on DBFS under
    ``<staging_prefix>/<run_id>/<sanitized step_key>/``. After the run has finished, the event
    records written by the remote process are read back from that directory and yielded.
    """

    def __init__(
        self,
        run_config,
        databricks_host,
        databricks_token,
        secrets_to_env_variables,
        storage,
        local_pipeline_package_path,
        staging_prefix,
        wait_for_logs,
        max_completion_wait_time_seconds,
    ):
        """Validate configuration and create the underlying ``DatabricksJobRunner``.

        ``staging_prefix`` must be an absolute path.
        """
        self.run_config = check.dict_param(run_config, "run_config")
        self.databricks_host = check.str_param(databricks_host, "databricks_host")
        self.databricks_token = check.str_param(databricks_token, "databricks_token")
        self.secrets = check.list_param(secrets_to_env_variables, "secrets_to_env_variables", dict)
        self.storage = check.dict_param(storage, "storage")
        self.local_pipeline_package_path = check.str_param(
            local_pipeline_package_path, "local_pipeline_package_path"
        )
        self.staging_prefix = check.str_param(staging_prefix, "staging_prefix")
        check.invariant(staging_prefix.startswith("/"), "staging_prefix must be an absolute path")
        self.wait_for_logs = check.bool_param(wait_for_logs, "wait_for_logs")

        self.databricks_runner = DatabricksJobRunner(
            host=databricks_host,
            token=databricks_token,
            max_wait_time_sec=max_completion_wait_time_seconds,
        )

    def launch_step(self, step_context, prior_attempts_count):
        """Run the step on Databricks, wait for completion, then yield the step's events.

        Cluster logs (when ``wait_for_logs`` is set) are logged from a ``finally`` block. Events
        are only read and yielded if waiting for the run did not raise.
        """
        step_run_ref = step_context_to_step_run_ref(
            step_context, prior_attempts_count, self.local_pipeline_package_path
        )
        run_id = step_context.pipeline_run.run_id
        log = step_context.log

        step_key = step_run_ref.step_key
        self._upload_artifacts(log, step_run_ref, run_id, step_key)

        task = self._get_databricks_task(run_id, step_key)
        databricks_run_id = self.databricks_runner.submit_run(self.run_config, task)

        try:
            # If this is being called within a `capture_interrupts` context, allow interrupts while
            # waiting for the execution to complete, so that we can terminate slow or hanging steps
            with raise_execution_interrupts():
                self.databricks_runner.wait_for_run_to_complete(log, databricks_run_id)
        finally:
            if self.wait_for_logs:
                self._log_logs_from_cluster(log, databricks_run_id)

        for event in self.get_step_events(run_id, step_key):
            log_step_event(step_context, event)
            yield event

    def get_step_events(self, run_id, step_key):
        """Read and deserialize the event records the remote step wrote to DBFS."""
        path = self._dbfs_path(run_id, step_key, PICKLED_EVENTS_FILE_NAME)
        events_data = self.databricks_runner.client.read_file(path)
        # NOTE(review): this unpickles bytes read back from DBFS, so it trusts anyone who can write
        # under the staging prefix.
        return deserialize_value(pickle.loads(events_data))

    def _get_databricks_task(self, run_id, step_key):
        """Construct the 'task' parameter to be submitted to the Databricks API.

        This will create a 'spark_python_task' dict where `python_file` is a path on DBFS
        pointing to the 'databricks_step_main.py' file, and `parameters` is a list of three
        `/dbfs/` paths: the pickled `step_run_ref`, the pickled Databricks config, and the
        zipped code package.

        See https://docs.databricks.com/dev-tools/api/latest/jobs.html#jobssparkpythontask.
        """
        python_file = self._dbfs_path(run_id, step_key, self._main_file_name())
        parameters = [
            self._internal_dbfs_path(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
            self._internal_dbfs_path(run_id, step_key, PICKLED_CONFIG_FILE_NAME),
            self._internal_dbfs_path(run_id, step_key, CODE_ZIP_NAME),
        ]
        return {"spark_python_task": {"python_file": python_file, "parameters": parameters}}

    def _upload_artifacts(self, log, step_run_ref, run_id, step_key):
        """Upload the step run ref and pyspark code to DBFS to run as a job."""

        log.info("Uploading main file to DBFS")
        main_local_path = self._main_file_local_path()
        with open(main_local_path, "rb") as infile:
            self.databricks_runner.client.put_file(
                infile, self._dbfs_path(run_id, step_key, self._main_file_name())
            )

        log.info("Uploading pipeline to DBFS")
        with tempfile.TemporaryDirectory() as temp_dir:
            # Zip and upload package containing pipeline
            zip_local_path = os.path.join(temp_dir, CODE_ZIP_NAME)
            build_pyspark_zip(zip_local_path, self.local_pipeline_package_path)
            with open(zip_local_path, "rb") as infile:
                self.databricks_runner.client.put_file(
                    infile, self._dbfs_path(run_id, step_key, CODE_ZIP_NAME)
                )

        log.info("Uploading step run ref file to DBFS")
        step_pickle_file = io.BytesIO()
        pickle.dump(step_run_ref, step_pickle_file)
        step_pickle_file.seek(0)
        self.databricks_runner.client.put_file(
            step_pickle_file,
            self._dbfs_path(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
        )

        databricks_config = DatabricksConfig(
            storage=self.storage,
            secrets=self.secrets,
        )
        log.info("Uploading Databricks configuration to DBFS")
        databricks_config_file = io.BytesIO()
        pickle.dump(databricks_config, databricks_config_file)
        databricks_config_file.seek(0)
        self.databricks_runner.client.put_file(
            databricks_config_file,
            self._dbfs_path(run_id, step_key, PICKLED_CONFIG_FILE_NAME),
        )

    def _log_logs_from_cluster(self, log, run_id):
        """Log the stderr and then stdout the runner retrieves for ``run_id``, if any."""
        logs = self.databricks_runner.retrieve_logs_for_run_id(log, run_id)
        if logs is None:
            return
        stdout, stderr = logs
        if stderr:
            log.info(stderr)
        if stdout:
            log.info(stdout)

    def _main_file_name(self):
        return os.path.basename(self._main_file_local_path())

    def _main_file_local_path(self):
        return databricks_step_main.__file__

    def _sanitize_step_key(self, step_key):
        # step_keys of dynamic steps contain brackets, which are invalid characters
        return step_key.replace("[", "__").replace("]", "__")

    def _staging_path(self, run_id, step_key, filename):
        """Join staging_prefix/run_id/sanitized step_key/basename(filename) with '/'.

        Both the DBFS URI and the /dbfs/ path are built from this helper, so every staged file
        of a step lands in the same directory.
        """
        return "/".join(
            [
                self.staging_prefix,
                run_id,
                self._sanitize_step_key(step_key),
                os.path.basename(filename),
            ]
        )

    def _dbfs_path(self, run_id, step_key, filename):
        """DBFS URI for a staged file, as used by the client uploading and reading files."""
        return "dbfs://{}".format(self._staging_path(run_id, step_key, filename))

    def _internal_dbfs_path(self, run_id, step_key, filename):
        """Scripts running on Databricks should access DBFS at /dbfs/."""
        return "/dbfs/{}".format(self._staging_path(run_id, step_key, filename))