Example #1
    def test_exec_failure(self, db_mock_class):
        """
        Test the execute function in the case where the run fails.
        """
        run = {
            'notebook_params': NOTEBOOK_PARAMS,
            'notebook_task': NOTEBOOK_TASK,
            'jar_params': JAR_PARAMS
        }
        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
        db_mock = db_mock_class.return_value
        db_mock.run_now.return_value = 1
        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED',
                                                      '')

        with self.assertRaises(AirflowException):
            op.execute(None)

        expected = databricks_operator._deep_string_coerce({
            'notebook_params': NOTEBOOK_PARAMS,
            'notebook_task': NOTEBOOK_TASK,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID
        })
        db_mock_class.assert_called_once_with(
            DEFAULT_CONN_ID,
            retry_limit=op.databricks_retry_limit,
            retry_delay=op.databricks_retry_delay)
        db_mock.run_now.assert_called_once_with(expected)
        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
        db_mock.get_run_state.assert_called_once_with(RUN_ID)
        self.assertEqual(RUN_ID, op.run_id)
Example #2
def _handle_databricks_operator_execution(operator, hook, log,
                                          context) -> None:
    """
    Handles the Airflow + Databricks lifecycle logic for a Databricks operator.

    :param operator: Databricks operator being handled
    :param hook: Databricks hook used to poll the run and fetch its output
    :param log: logger used to report run status
    :param context: Airflow context
    """
    if operator.do_xcom_push and context is not None:
        context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
    log.info('Run submitted with run_id: %s', operator.run_id)
    run_page_url = hook.get_run_page_url(operator.run_id)
    if operator.do_xcom_push and context is not None:
        context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)

    if operator.wait_for_termination:
        while True:
            run_info = hook.get_run(operator.run_id)
            run_state = RunState(**run_info['state'])
            if run_state.is_terminal:
                if run_state.is_successful:
                    log.info('%s completed successfully.', operator.task_id)
                    log.info('View run status, Spark UI, and logs at %s',
                             run_page_url)
                    return
                else:
                    if run_state.result_state == "FAILED":
                        task_run_id = None
                        if 'tasks' in run_info:
                            for task in run_info['tasks']:
                                if task.get("state",
                                            {}).get("result_state",
                                                    "") == "FAILED":
                                    task_run_id = task["run_id"]
                        if task_run_id is not None:
                            run_output = hook.get_run_output(task_run_id)
                            if 'error' in run_output:
                                notebook_error = run_output['error']
                            else:
                                notebook_error = run_state.state_message
                        else:
                            notebook_error = run_state.state_message
                        error_message = (
                            f'{operator.task_id} failed with terminal state: {run_state} '
                            f'and with the error {notebook_error}')
                    else:
                        error_message = (
                            f'{operator.task_id} failed with terminal state: {run_state} '
                            f'and with the error {run_state.state_message}')
                    raise AirflowException(error_message)

            else:
                log.info('%s in run state: %s', operator.task_id, run_state)
                log.info('View run status, Spark UI, and logs at %s',
                         run_page_url)
                log.info('Sleeping for %s seconds.',
                         operator.polling_period_seconds)
                time.sleep(operator.polling_period_seconds)
    else:
        log.info('View run status, Spark UI, and logs at %s', run_page_url)
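
For context, a hedged sketch of how an operator's execute method might delegate to this helper; _get_hook is a hypothetical accessor used here for illustration, not necessarily the provider's exact API:

def execute(self, context):
    # Sketch only: obtain a hook, start the run, then hand off to the
    # shared lifecycle handler defined above.
    hook = self._get_hook()                    # hypothetical accessor
    self.run_id = hook.submit_run(self.json)   # kick off the Databricks run
    _handle_databricks_operator_execution(self, hook, self.log, context)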
Example #3
    def test_exec_failure(self, db_mock_class):
        """
        Test the execute function in the case where the run fails.
        """
        run = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
        }
        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
        db_mock = db_mock_class.return_value
        db_mock.submit_run.return_value = 1
        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED',
                                                      '')

        with self.assertRaises(AirflowException):
            op.execute(None)

        expected = databricks_operator._deep_string_coerce({
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': TASK_ID,
        })
        db_mock_class.assert_called_once_with(
            DEFAULT_CONN_ID,
            retry_limit=op.databricks_retry_limit,
            retry_delay=op.databricks_retry_delay)
        db_mock.submit_run.assert_called_once_with(expected)
        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
        db_mock.get_run_state.assert_called_once_with(RUN_ID)
        self.assertEqual(RUN_ID, op.run_id)
Example #4
    def test_exec_success(self, db_mock_class):
        """
        Test the execute function in the case where the run is successful.
        """
        run = {
            'notebook_params': NOTEBOOK_PARAMS,
            'notebook_task': NOTEBOOK_TASK,
            'jar_params': JAR_PARAMS
        }
        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
        db_mock = db_mock_class.return_value
        db_mock.run_now.return_value = 1
        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS',
                                                      '')

        op.execute(None)

        expected = databricks_operator._deep_string_coerce({
            'notebook_params': NOTEBOOK_PARAMS,
            'notebook_task': NOTEBOOK_TASK,
            'jar_params': JAR_PARAMS,
            'job_id': JOB_ID,
        })

        db_mock_class.assert_called_once_with(
            DEFAULT_CONN_ID,
            retry_limit=op.databricks_retry_limit,
            retry_delay=op.databricks_retry_delay)
        db_mock.run_now.assert_called_once_with(expected)
        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
        db_mock.get_run_state.assert_called_once_with(RUN_ID)
        assert RUN_ID == op.run_id
Example #5
    def test_exec_success(self, db_mock_class):
        """
        Test the execute function in the case where the run is successful.
        """
        run = {
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
        }
        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
        db_mock = db_mock_class.return_value
        db_mock.submit_run.return_value = 1
        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS',
                                                      '')

        op.execute(None)

        expected = databricks_operator._deep_string_coerce({
            'new_cluster': NEW_CLUSTER,
            'notebook_task': NOTEBOOK_TASK,
            'run_name': TASK_ID
        })
        db_mock_class.assert_called_once_with(
            DEFAULT_CONN_ID,
            retry_limit=op.databricks_retry_limit,
            retry_delay=op.databricks_retry_delay)

        db_mock.submit_run.assert_called_once_with(expected)
        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
        db_mock.get_run_state.assert_called_once_with(RUN_ID)
        assert RUN_ID == op.run_id
Example #6
def validate_trigger_event(event: dict):
    """
    Validates the correctness of the event
    received from :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`
    """
    keys_to_check = ['run_id', 'run_page_url', 'run_state']
    for key in keys_to_check:
        if key not in event:
            raise AirflowException(
                f'Could not find `{key}` in the event: {event}')

    try:
        RunState.from_json(event['run_state'])
    except Exception:
        raise AirflowException(
            f'Run state returned by the Trigger is incorrect: {event["run_state"]}'
        )
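
A minimal sketch of an event payload that passes this validation; the values are illustrative, and to_json is assumed to be the serializing counterpart of the from_json call above:

# Illustrative payload only; values are made up for the example.
event = {
    'run_id': 42,
    'run_page_url': 'https://example.cloud.databricks.com/#job/1/run/42',
    'run_state': RunState('TERMINATED', 'SUCCESS', '').to_json(),
}
validate_trigger_event(event)  # raises AirflowException on a bad payload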
Example #7
def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None:
    validate_trigger_event(event)
    run_state = RunState.from_json(event['run_state'])
    run_page_url = event['run_page_url']
    log.info(f'View run status, Spark UI, and logs at {run_page_url}')

    if run_state.is_successful:
        log.info('Job run completed successfully.')
        return
    else:
        error_message = f'Job run failed with terminal state: {run_state}'
        raise AirflowException(error_message)
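
In a deferrable operator this helper would typically be invoked from the method the trigger resumes into; a minimal sketch, assuming Airflow's conventional execute_complete resume hook:

def execute_complete(self, context, event=None):
    # Sketch only: the trigger's event dict is validated and turned into
    # either a success log line or an AirflowException by the helper above.
    _handle_deferrable_databricks_operator_completion(event, self.log)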
Example #8
    def test_get_run_state(self, mock_requests):
        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE

        run_state = self.hook.get_run_state(RUN_ID)

        self.assertEqual(
            run_state, RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE))
        mock_requests.get.assert_called_once_with(
            get_run_endpoint(HOST),
            json={'run_id': RUN_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)
Example #9
    def test_is_successful(self):
        run_state = RunState('TERMINATED', 'SUCCESS', '')
        assert run_state.is_successful
Example #10
    def test_is_terminal_with_nonexistent_life_cycle_state(self):
        run_state = RunState('blah', '', '')
        with pytest.raises(AirflowException):
            run_state.is_terminal
Example #11
    def test_is_terminal_false(self):
        non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING']
        for state in non_terminal_states:
            run_state = RunState(state, '', '')
            assert not run_state.is_terminal
Example #12
    def test_is_terminal_true(self):
        terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
        for state in terminal_states:
            run_state = RunState(state, '', '')
            assert run_state.is_terminal
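
For readers without the provider installed, a minimal sketch of the RunState semantics these tests exercise; this is an assumption-laden reimplementation for illustration, not the provider's actual code:

# Sketch of RunState semantics matching the tests above; not the real class.
from airflow.exceptions import AirflowException

RUN_LIFE_CYCLE_STATES = [
    'PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'
]

class RunState:
    def __init__(self, life_cycle_state, result_state, state_message):
        self.life_cycle_state = life_cycle_state
        self.result_state = result_state
        self.state_message = state_message

    @property
    def is_terminal(self):
        # Unknown life-cycle states raise, as the nonexistent-state test expects.
        if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
            raise AirflowException(
                f'Unexpected life cycle state: {self.life_cycle_state}')
        return self.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR')

    @property
    def is_successful(self):
        return self.is_terminal and self.result_state == 'SUCCESS'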