Example #1
def test_databricks_submit_job_existing_cluster(mock_submit_run,
                                                databricks_run_config):
    mock_submit_run.return_value = {'run_id': 1}

    runner = DatabricksJobRunner(HOST, TOKEN)
    task = databricks_run_config.pop('task')
    runner.submit_run(databricks_run_config, task)
    mock_submit_run.assert_called_once_with(
        run_name=databricks_run_config['run_name'],
        new_cluster=None,
        existing_cluster_id=databricks_run_config['cluster']['existing'],
        spark_jar_task=task['spark_jar_task'],
        libraries=[
            {
                'pypi': {
                    'package': 'dagster=={}'.format(dagster.__version__)
                }
            },
            {
                'pypi': {
                    'package':
                    'dagster_databricks=={}'.format(dagster.__version__)
                }
            },
            {
                'pypi': {
                    'package':
                    'dagster_pyspark=={}'.format(dagster.__version__)
                }
            },
        ],
    )
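For context, the tests in these examples assume a pytest fixture named `databricks_run_config` plus a patched `submit_run` call. A minimal sketch of what such fixtures might look like (the patch target, host, token, and config values are illustrative assumptions, not the actual fixtures from this test suite):

import pytest
from unittest import mock

HOST = "https://test.cloud.databricks.com"  # assumed placeholder
TOKEN = "dapi-test-token"  # assumed placeholder


@pytest.fixture
def databricks_run_config():
    # Hypothetical fixture: the keys mirror what the tests read
    # ('run_name', 'cluster', 'task'); the values are illustrative.
    return {
        "run_name": "dagster-databricks-test",
        "cluster": {"existing": "0123-456789-abc123"},
        "task": {"spark_jar_task": {"main_class_name": "com.example.Main"}},
    }


@pytest.fixture
def mock_submit_run():
    # Hypothetical patch target: whichever client method DatabricksJobRunner
    # ultimately delegates to when submitting a run.
    with mock.patch(
        "dagster_databricks.databricks.DatabricksClient.submit_run"
    ) as m:
        yield m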
Example #2
def test_databricks_submit_job_new_cluster(mock_submit_run, databricks_run_config):
    mock_submit_run.return_value = {"run_id": 1}

    runner = DatabricksJobRunner(HOST, TOKEN)

    NEW_CLUSTER = {
        "size": {"num_workers": 1},
        "spark_version": "6.5.x-scala2.11",
        "nodes": {"node_types": {"node_type_id": "Standard_DS3_v2"}},
    }
    databricks_run_config["cluster"] = {"new": NEW_CLUSTER}

    task = databricks_run_config.pop("task")
    runner.submit_run(databricks_run_config, task)
    mock_submit_run.assert_called_once_with(
        run_name=databricks_run_config["run_name"],
        new_cluster={
            "num_workers": 1,
            "spark_version": "6.5.x-scala2.11",
            "node_type_id": "Standard_DS3_v2",
            "custom_tags": [{"key": "__dagster_version", "value": dagster.__version__}],
        },
        existing_cluster_id=None,
        spark_jar_task=task["spark_jar_task"],
        libraries=[
            {"pypi": {"package": "dagster=={}".format(dagster.__version__)}},
            {"pypi": {"package": "dagster_databricks=={}".format(dagster.__version__)}},
            {"pypi": {"package": "dagster_pyspark=={}".format(dagster.__version__)}},
        ],
    )
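The assertion above implies that `submit_run` flattens the nested `cluster.new` config into the flat `new_cluster` shape expected by the Databricks Runs Submit API and attaches a dagster-version tag. A minimal sketch of that flattening, written here as a standalone helper (the real logic lives inside `DatabricksJobRunner` and may differ):

import dagster


def flatten_new_cluster(new_cluster_config):
    # Hypothetical helper mirroring the behavior asserted above: the nested
    # 'size' and 'nodes.node_types' sections are merged into the flat dict
    # that the Databricks Runs Submit API expects, and a custom tag recording
    # the dagster version is attached.
    flattened = dict(new_cluster_config.get("size", {}))
    flattened["spark_version"] = new_cluster_config["spark_version"]
    flattened.update(new_cluster_config.get("nodes", {}).get("node_types", {}))
    flattened["custom_tags"] = [
        {"key": "__dagster_version", "value": dagster.__version__}
    ]
    return flattened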
Example #3
def test_databricks_submit_job_existing_cluster(mock_submit_run,
                                                databricks_run_config):
    mock_submit_run.return_value = {"run_id": 1}

    runner = DatabricksJobRunner(HOST, TOKEN)
    task = databricks_run_config.pop("task")
    runner.submit_run(databricks_run_config, task)
    mock_submit_run.assert_called_once_with(
        run_name=databricks_run_config["run_name"],
        new_cluster=None,
        existing_cluster_id=databricks_run_config["cluster"]["existing"],
        spark_jar_task=task["spark_jar_task"],
        libraries=[
            {
                "pypi": {
                    "package": "dagster=={}".format(dagster.__version__)
                }
            },
            {
                "pypi": {
                    "package":
                    "dagster_databricks=={}".format(dagster.__version__)
                }
            },
            {
                "pypi": {
                    "package":
                    "dagster_pyspark=={}".format(dagster.__version__)
                }
            },
        ],
    )
Example #4
def test_databricks_submit_job_new_cluster(mock_submit_run,
                                           databricks_run_config):
    mock_submit_run.return_value = {'run_id': 1}

    runner = DatabricksJobRunner(HOST, TOKEN)

    NEW_CLUSTER = {
        'size': {
            'num_workers': 1
        },
        'spark_version': '6.5.x-scala2.11',
        'nodes': {
            'node_types': {
                'node_type_id': 'Standard_DS3_v2'
            }
        },
    }
    databricks_run_config['cluster'] = {'new': NEW_CLUSTER}

    task = databricks_run_config.pop('task')
    runner.submit_run(databricks_run_config, task)
    mock_submit_run.assert_called_once_with(
        run_name=databricks_run_config['run_name'],
        new_cluster={
            'num_workers': 1,
            'spark_version': '6.5.x-scala2.11',
            'node_type_id': 'Standard_DS3_v2',
            'custom_tags': [{
                'key': '__dagster_version',
                'value': dagster.__version__
            }],
        },
        existing_cluster_id=None,
        spark_jar_task=task['spark_jar_task'],
        libraries=[
            {
                'pypi': {
                    'package': 'dagster=={}'.format(dagster.__version__)
                }
            },
            {
                'pypi': {
                    'package':
                    'dagster_databricks=={}'.format(dagster.__version__)
                }
            },
            {
                'pypi': {
                    'package':
                    'dagster_pyspark=={}'.format(dagster.__version__)
                }
            },
        ],
    )
Example #5
def test_databricks_wait_for_run(mock_submit_run, databricks_run_config):
    mock_submit_run.return_value = {'run_id': 1}

    context = create_test_pipeline_execution_context()
    runner = DatabricksJobRunner(HOST, TOKEN, poll_interval_sec=0.01)
    task = databricks_run_config.pop('task')
    databricks_run_id = runner.submit_run(databricks_run_config, task)

    calls = {
        'num_calls': 0,
        'final_state': DatabricksRunState(
            DatabricksRunLifeCycleState.Terminated,
            DatabricksRunResultState.Success,
            'Finished',
        ),
    }

    def new_get_run_state(_run_id):
        calls['num_calls'] += 1

        if calls['num_calls'] == 1:
            return DatabricksRunState(
                DatabricksRunLifeCycleState.Pending,
                None,
                None,
            )
        elif calls['num_calls'] == 2:
            return DatabricksRunState(
                DatabricksRunLifeCycleState.Running,
                None,
                None,
            )
        else:
            return calls['final_state']

    with mock.patch.object(runner.client,
                           'get_run_state',
                           new=new_get_run_state):
        runner.wait_for_run_to_complete(context.log, databricks_run_id)

    calls['num_calls'] = 0
    calls['final_state'] = DatabricksRunState(
        DatabricksRunLifeCycleState.Terminated,
        DatabricksRunResultState.Failed,
        'Failed',
    )
    with pytest.raises(DatabricksError) as exc_info:
        with mock.patch.object(runner.client,
                               'get_run_state',
                               new=new_get_run_state):
            runner.wait_for_run_to_complete(context.log, databricks_run_id)
    assert 'Run 1 failed with result state' in str(exc_info.value)
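For reference, the run-state objects used above can be modeled as a thin wrapper over the Databricks life-cycle and result states. A rough sketch, assuming a namedtuple-style container with convenience predicates (the real classes in dagster_databricks may define more states and methods):

from collections import namedtuple
from enum import Enum


class DatabricksRunLifeCycleState(Enum):
    # Only the states exercised by the test above are listed here.
    Pending = "PENDING"
    Running = "RUNNING"
    Terminated = "TERMINATED"


class DatabricksRunResultState(Enum):
    Success = "SUCCESS"
    Failed = "FAILED"


class DatabricksRunState(
    namedtuple("_DatabricksRunState", "life_cycle_state result_state state_message")
):
    def has_terminated(self):
        # Polling stops once the run reaches a terminal life-cycle state.
        return self.life_cycle_state == DatabricksRunLifeCycleState.Terminated

    def is_successful(self):
        return (
            self.has_terminated()
            and self.result_state == DatabricksRunResultState.Success
        )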
Example #6
class DatabricksPySparkStepLauncher(StepLauncher):
    def __init__(
        self,
        run_config,
        databricks_host,
        databricks_token,
        secrets_to_env_variables,
        storage,
        staging_prefix,
        wait_for_logs,
        max_completion_wait_time_seconds,
        poll_interval_sec=5,
        local_pipeline_package_path=None,
        local_dagster_job_package_path=None,
    ):
        self.run_config = check.dict_param(run_config, "run_config")
        self.databricks_host = check.str_param(databricks_host,
                                               "databricks_host")
        self.databricks_token = check.str_param(databricks_token,
                                                "databricks_token")
        self.secrets = check.list_param(secrets_to_env_variables,
                                        "secrets_to_env_variables", dict)
        self.storage = check.dict_param(storage, "storage")
        check.invariant(
            local_dagster_job_package_path is not None
            or local_pipeline_package_path is not None,
            "Missing config: need to provide either 'local_dagster_job_package_path' or 'local_pipeline_package_path' config entry",
        )
        check.invariant(
            local_dagster_job_package_path is None
            or local_pipeline_package_path is None,
            "Error in config: Provided both 'local_dagster_job_package_path' and 'local_pipeline_package_path' entries. Need to specify one or the other.",
        )
        self.local_dagster_job_package_path = check.str_param(
            local_pipeline_package_path or local_dagster_job_package_path,
            "local_dagster_job_package_path",
        )
        self.staging_prefix = check.str_param(staging_prefix, "staging_prefix")
        check.invariant(staging_prefix.startswith("/"),
                        "staging_prefix must be an absolute path")
        self.wait_for_logs = check.bool_param(wait_for_logs, "wait_for_logs")

        self.databricks_runner = DatabricksJobRunner(
            host=databricks_host,
            token=databricks_token,
            poll_interval_sec=poll_interval_sec,
            max_wait_time_sec=max_completion_wait_time_seconds,
        )

    def launch_step(self, step_context, prior_attempts_count):
        step_run_ref = step_context_to_step_run_ref(
            step_context, prior_attempts_count,
            self.local_dagster_job_package_path)
        run_id = step_context.pipeline_run.run_id
        log = step_context.log

        step_key = step_run_ref.step_key
        self._upload_artifacts(log, step_run_ref, run_id, step_key)

        task = self._get_databricks_task(run_id, step_key)
        databricks_run_id = self.databricks_runner.submit_run(
            self.run_config, task)

        try:
            # If this is being called within a `capture_interrupts` context, allow interrupts while
            # waiting for the execution to complete, so that we can terminate slow or hanging steps
            with raise_execution_interrupts():
                yield from self.step_events_iterator(step_context, step_key,
                                                     databricks_run_id)
        finally:
            self.log_compute_logs(log, run_id, step_key)
            # somewhat obsolete: stdout/stderr for the step are already captured above
            if self.wait_for_logs:
                self._log_logs_from_cluster(log, databricks_run_id)

    def log_compute_logs(self, log, run_id, step_key):
        stdout = self.databricks_runner.client.read_file(
            self._dbfs_path(run_id, step_key, "stdout")).decode()
        stderr = self.databricks_runner.client.read_file(
            self._dbfs_path(run_id, step_key, "stderr")).decode()
        log.info(f"Captured stdout for step {step_key}:")
        log.info(stdout)
        log.info(f"Captured stderr for step {step_key}:")
        log.info(stderr)

    def step_events_iterator(self, step_context, step_key: str,
                             databricks_run_id: int):
        """The launched Databricks job writes all event records to a specific dbfs file. This iterator
        regularly reads the contents of the file, adds any events that have not yet been seen to
        the instance, and yields any DagsterEvents.

        By doing this, we simulate having the remote Databricks process able to directly write to
        the local DagsterInstance. Importantly, this means that timestamps (and all other record
        properties) will be sourced from the Databricks process, rather than recording when this
        process happens to log them.
        """

        check.int_param(databricks_run_id, "databricks_run_id")
        processed_events = 0
        start = time.time()
        done = False
        step_context.log.info("Waiting for Databricks run %s to complete..." %
                              databricks_run_id)
        while not done:
            with raise_execution_interrupts():
                step_context.log.debug(
                    "Waiting %.1f seconds..." %
                    self.databricks_runner.poll_interval_sec)
                time.sleep(self.databricks_runner.poll_interval_sec)
                try:
                    done = poll_run_state(
                        self.databricks_runner.client,
                        step_context.log,
                        start,
                        databricks_run_id,
                        self.databricks_runner.max_wait_time_sec,
                    )
                finally:
                    all_events = self.get_step_events(step_context.run_id,
                                                      step_key)
                    # we get all available records on each poll, but we only want to process the
                    # ones we haven't seen before
                    for event in all_events[processed_events:]:
                        # write each event from the Databricks instance to the local instance
                        step_context.instance.handle_new_event(event)
                        if event.is_dagster_event:
                            yield event.dagster_event
                    processed_events = len(all_events)

        step_context.log.info(f"Databricks run {databricks_run_id} completed.")

    def get_step_events(self, run_id: str, step_key: str):
        path = self._dbfs_path(run_id, step_key, PICKLED_EVENTS_FILE_NAME)

        def _get_step_records():
            serialized_records = self.databricks_runner.client.read_file(path)
            if not serialized_records:
                return []
            return deserialize_value(pickle.loads(serialized_records))

        try:
            # reading from dbfs while it writes can be flaky
            # allow for retry if we get malformed data
            return backoff(
                fn=_get_step_records,
                retry_on=(pickle.UnpicklingError, ),
                max_retries=2,
            )
        # if you poll before the Databricks process has had a chance to create the file,
        # we expect to get this error
        except HTTPError as e:
            if e.response.json().get(
                    "error_code") == "RESOURCE_DOES_NOT_EXIST":
                return []

        return []

    def _get_databricks_task(self, run_id, step_key):
        """Construct the 'task' parameter to  be submitted to the Databricks API.

        This will create a 'spark_python_task' dict where `python_file` is a path on DBFS
        pointing to the 'databricks_step_main.py' file, and `parameters` is an array with a single
        element, a path on DBFS pointing to the picked `step_run_ref` data.

        See https://docs.databricks.com/dev-tools/api/latest/jobs.html#jobssparkpythontask.
        """
        python_file = self._dbfs_path(run_id, step_key, self._main_file_name())
        parameters = [
            self._internal_dbfs_path(run_id, step_key,
                                     PICKLED_STEP_RUN_REF_FILE_NAME),
            self._internal_dbfs_path(run_id, step_key,
                                     PICKLED_CONFIG_FILE_NAME),
            self._internal_dbfs_path(run_id, step_key, CODE_ZIP_NAME),
        ]
        return {
            "spark_python_task": {
                "python_file": python_file,
                "parameters": parameters
            }
        }

    def _upload_artifacts(self, log, step_run_ref, run_id, step_key):
        """Upload the step run ref and pyspark code to DBFS to run as a job."""

        log.info("Uploading main file to DBFS")
        main_local_path = self._main_file_local_path()
        with open(main_local_path, "rb") as infile:
            self.databricks_runner.client.put_file(
                infile,
                self._dbfs_path(run_id, step_key, self._main_file_name()))

        log.info("Uploading dagster job to DBFS")
        with tempfile.TemporaryDirectory() as temp_dir:
            # Zip and upload package containing dagster job
            zip_local_path = os.path.join(temp_dir, CODE_ZIP_NAME)
            build_pyspark_zip(zip_local_path,
                              self.local_dagster_job_package_path)
            with open(zip_local_path, "rb") as infile:
                self.databricks_runner.client.put_file(
                    infile, self._dbfs_path(run_id, step_key, CODE_ZIP_NAME))

        log.info("Uploading step run ref file to DBFS")
        step_pickle_file = io.BytesIO()

        pickle.dump(step_run_ref, step_pickle_file)
        step_pickle_file.seek(0)
        self.databricks_runner.client.put_file(
            step_pickle_file,
            self._dbfs_path(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
        )

        databricks_config = DatabricksConfig(
            storage=self.storage,
            secrets=self.secrets,
        )
        log.info("Uploading Databricks configuration to DBFS")
        databricks_config_file = io.BytesIO()

        pickle.dump(databricks_config, databricks_config_file)
        databricks_config_file.seek(0)
        self.databricks_runner.client.put_file(
            databricks_config_file,
            self._dbfs_path(run_id, step_key, PICKLED_CONFIG_FILE_NAME),
        )

    def _log_logs_from_cluster(self, log, run_id):
        logs = self.databricks_runner.retrieve_logs_for_run_id(log, run_id)
        if logs is None:
            return
        stdout, stderr = logs
        if stderr:
            log.info(stderr)
        if stdout:
            log.info(stdout)

    def _main_file_name(self):
        return os.path.basename(self._main_file_local_path())

    def _main_file_local_path(self):
        return databricks_step_main.__file__

    def _sanitize_step_key(self, step_key: str) -> str:
        # step_keys of dynamic steps contain brackets, which are invalid characters
        return step_key.replace("[", "__").replace("]", "__")

    def _dbfs_path(self, run_id, step_key, filename):
        path = "/".join([
            self.staging_prefix,
            run_id,
            self._sanitize_step_key(step_key),
            os.path.basename(filename),
        ])
        return "dbfs://{}".format(path)

    def _internal_dbfs_path(self, run_id, step_key, filename):
        """Scripts running on Databricks should access DBFS at /dbfs/."""
        path = "/".join([
            self.staging_prefix,
            run_id,
            self._sanitize_step_key(step_key),
            os.path.basename(filename),
        ])
        return "/dbfs/{}".format(path)
class DatabricksPySparkStepLauncher(StepLauncher):
    def __init__(
        self,
        run_config,
        databricks_host,
        databricks_token,
        secrets_to_env_variables,
        storage,
        local_pipeline_package_path,
        staging_prefix,
        wait_for_logs,
        max_completion_wait_time_seconds,
    ):
        self.run_config = check.dict_param(run_config, "run_config")
        self.databricks_host = check.str_param(databricks_host,
                                               "databricks_host")
        self.databricks_token = check.str_param(databricks_token,
                                                "databricks_token")
        self.secrets = check.list_param(secrets_to_env_variables,
                                        "secrets_to_env_variables", dict)
        self.storage = check.dict_param(storage, "storage")
        self.local_pipeline_package_path = check.str_param(
            local_pipeline_package_path, "local_pipeline_package_path")
        self.staging_prefix = check.str_param(staging_prefix, "staging_prefix")
        check.invariant(staging_prefix.startswith("/"),
                        "staging_prefix must be an absolute path")
        self.wait_for_logs = check.bool_param(wait_for_logs, "wait_for_logs")

        self.databricks_runner = DatabricksJobRunner(
            host=databricks_host,
            token=databricks_token,
            max_wait_time_sec=max_completion_wait_time_seconds,
        )

    def launch_step(self, step_context, prior_attempts_count):
        step_run_ref = step_context_to_step_run_ref(
            step_context, prior_attempts_count,
            self.local_pipeline_package_path)
        run_id = step_context.pipeline_run.run_id
        log = step_context.log

        step_key = step_run_ref.step_key
        self._upload_artifacts(log, step_run_ref, run_id, step_key)

        task = self._get_databricks_task(run_id, step_key)
        databricks_run_id = self.databricks_runner.submit_run(
            self.run_config, task)

        try:
            # If this is being called within a `capture_interrupts` context, allow interrupts while
            # waiting for the execution to complete, so that we can terminate slow or hanging steps
            with raise_execution_interrupts():
                self.databricks_runner.wait_for_run_to_complete(
                    log, databricks_run_id)
        finally:
            if self.wait_for_logs:
                self._log_logs_from_cluster(log, databricks_run_id)

        for event in self.get_step_events(run_id, step_key):
            log_step_event(step_context, event)
            yield event

    def get_step_events(self, run_id, step_key):
        path = self._dbfs_path(run_id, step_key, PICKLED_EVENTS_FILE_NAME)
        events_data = self.databricks_runner.client.read_file(path)
        return deserialize_value(pickle.loads(events_data))

    def _get_databricks_task(self, run_id, step_key):
        """Construct the 'task' parameter to  be submitted to the Databricks API.

        This will create a 'spark_python_task' dict where `python_file` is a path on DBFS
        pointing to the 'databricks_step_main.py' file, and `parameters` is an array with a single
        element, a path on DBFS pointing to the picked `step_run_ref` data.

        See https://docs.databricks.com/dev-tools/api/latest/jobs.html#jobssparkpythontask.
        """
        python_file = self._dbfs_path(run_id, step_key, self._main_file_name())
        parameters = [
            self._internal_dbfs_path(run_id, step_key,
                                     PICKLED_STEP_RUN_REF_FILE_NAME),
            self._internal_dbfs_path(run_id, step_key,
                                     PICKLED_CONFIG_FILE_NAME),
            self._internal_dbfs_path(run_id, step_key, CODE_ZIP_NAME),
        ]
        return {
            "spark_python_task": {
                "python_file": python_file,
                "parameters": parameters
            }
        }

    def _upload_artifacts(self, log, step_run_ref, run_id, step_key):
        """Upload the step run ref and pyspark code to DBFS to run as a job."""

        log.info("Uploading main file to DBFS")
        main_local_path = self._main_file_local_path()
        with open(main_local_path, "rb") as infile:
            self.databricks_runner.client.put_file(
                infile,
                self._dbfs_path(run_id, step_key, self._main_file_name()))

        log.info("Uploading pipeline to DBFS")
        with tempfile.TemporaryDirectory() as temp_dir:
            # Zip and upload package containing pipeline
            zip_local_path = os.path.join(temp_dir, CODE_ZIP_NAME)
            build_pyspark_zip(zip_local_path, self.local_pipeline_package_path)
            with open(zip_local_path, "rb") as infile:
                self.databricks_runner.client.put_file(
                    infile, self._dbfs_path(run_id, step_key, CODE_ZIP_NAME))

        log.info("Uploading step run ref file to DBFS")
        step_pickle_file = io.BytesIO()

        pickle.dump(step_run_ref, step_pickle_file)
        step_pickle_file.seek(0)
        self.databricks_runner.client.put_file(
            step_pickle_file,
            self._dbfs_path(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
        )

        databricks_config = DatabricksConfig(
            storage=self.storage,
            secrets=self.secrets,
        )
        log.info("Uploading Databricks configuration to DBFS")
        databricks_config_file = io.BytesIO()

        pickle.dump(databricks_config, databricks_config_file)
        databricks_config_file.seek(0)
        self.databricks_runner.client.put_file(
            databricks_config_file,
            self._dbfs_path(run_id, step_key, PICKLED_CONFIG_FILE_NAME),
        )

    def _log_logs_from_cluster(self, log, run_id):
        logs = self.databricks_runner.retrieve_logs_for_run_id(log, run_id)
        if logs is None:
            return
        stdout, stderr = logs
        if stderr:
            log.info(stderr)
        if stdout:
            log.info(stdout)

    def _main_file_name(self):
        return os.path.basename(self._main_file_local_path())

    def _main_file_local_path(self):
        return databricks_step_main.__file__

    def _dbfs_path(self, run_id, step_key, filename):
        path = "/".join([
            self.staging_prefix, run_id, step_key,
            os.path.basename(filename)
        ])
        return "dbfs://{}".format(path)

    def _internal_dbfs_path(self, run_id, step_key, filename):
        """Scripts running on Databricks should access DBFS at /dbfs/."""
        path = "/".join([
            self.staging_prefix, run_id, step_key,
            os.path.basename(filename)
        ])
        return "/dbfs/{}".format(path)