Example #1
0
    def kill_run(self, message=None):
        """Request that the whole databand run (``self.run``) be killed.

        Sets the module-level ``_is_killed`` flag and sends a kill request for
        ``self.run.run_uid`` through the module-level ``kill_run`` API function (not this
        method). If the currently executing task run belongs to ``self.run``, it is the
        initiator of the kill: it is marked FAILED with an error carrying the message, and
        a ``DatabandError`` is raised after the kill request has been sent.

        :param message: error message to attach; ``DEFAULT_TASK_CANCELED_ERR_MSG`` is used
            when it is empty or None.
        :raises DatabandFailFastError: if the kill request could not be sent (the original
            exception is chained as ``__cause__``).
        :raises DatabandError: when the current task run belongs to this run.
        """
        _is_killed.set()
        error_message = message or DEFAULT_TASK_CANCELED_ERR_MSG

        # When initiating kill_run, the api's kill_run sends a signal to all running runs
        # to change their state to shutdown, which in the end sets it to cancelled.
        # The task which initiated the killing (the current task run) should end up
        # Failed, not Cancelled. It is set with an error so the error message is passed
        # along and displayed in the UI as the error_message for the whole run.
        tr = current_task_run()
        # NOTE(review): guard against having no current task run; the original
        # dereferenced ``tr.run`` unconditionally, failing before the kill was sent.
        initiated_by_this_run = tr is not None and tr.run == self.run
        if initiated_by_this_run:
            task_run_error = TaskRunError.build_from_message(
                task_run=tr,
                msg=error_message,
                help_msg="task with task_run_uid:%s initiated kill_run"
                % (tr.task_run_uid,),
                ex_class=DbndCanceledRunError,
            )
            tr.set_task_run_state(
                TaskRunState.FAILED, track=True, error=task_run_error
            )

        try:
            # module-level API function, not a recursive call to this method
            kill_run(str(self.run.run_uid), ctx=self.run.context)
        except Exception as e:
            raise DatabandFailFastError(
                "Could not send request to kill databand run!", e
            ) from e

        if initiated_by_this_run:
            raise DatabandError(error_message)
Example #2
0
    def on_failure(self, luigi_task, exc):
        """Finish the dbnd task that backs ``luigi_task`` with state FAILED.

        The exception ``exc`` is wrapped in a ``TaskRunError`` bound to the dbnd task's
        current task run and handed to the run manager together with the FAILED state.
        """
        from dbnd._core.task_run.task_run_error import TaskRunError

        manager = self.run_manager
        failed_task = manager.get_dbnd_task(luigi_task)
        failure = TaskRunError.build_from_ex(exc, failed_task.current_task_run)
        manager.finish_task(failed_task, TaskRunState.FAILED, err=failure)
Example #3
0
    def stop_on_exception(self, type, value, traceback):
        """Excepthook-style handler: record an uncaught exception, then stop tracking.

        The signature mirrors ``sys.excepthook`` (``type`` shadows the builtin but is kept
        for interface compatibility). While tracking is active, the exception is turned
        into a ``TaskRunError`` and the tracked task run is marked FAILED. Any failure
        while doing so is reported through ``_handle_tracking_error`` instead of
        propagating. Finally the tracker is stopped and Python's default
        ``sys.__excepthook__`` prints the original exception.
        """
        if self._active:
            try:
                error = TaskRunError.build_from_ex(
                    ex=value, task_run=self._task_run, exc_info=(type, value, traceback)
                )
                self._task_run.set_task_run_state(TaskRunState.FAILED, error=error)
            except Exception:
                # Best effort: a tracking failure must not hide the user's error.
                # Narrowed from a bare ``except:`` so SystemExit/KeyboardInterrupt are
                # no longer silently swallowed here.
                _handle_tracking_error("dbnd-set-script-error")

        self.stop()
        sys.__excepthook__(type, value, traceback)
Example #4
0
    def run_pod(self,
                task_run: "TaskRun",
                pod: "k8s.V1Pod",
                detach_run: bool = False) -> "DbndPodCtrl":
        """Create ``pod`` in Kubernetes for ``task_run`` and, unless detached, wait on it.

        Detaching is requested by ``detach_run`` or ``kube_config.detach_run`` but only
        honoured when ``is_possible_to_detach_run()`` allows it. The pod request gets a
        live-logs container attached, is dumped to the attempt's ``pod.yaml`` in debug
        mode, and external links are attached to the task run before it is moved to
        QUEUED. On ``ApiException`` the task run is marked FAILED and the error re-raised.

        :return: ``self``
        """
        kc = self.kube_config
        requested_detach = detach_run or kc.detach_run
        detach_run = requested_detach if self.is_possible_to_detach_run() else False

        pod_request = kc.build_kube_pod_req(pod)
        self._attach_live_logs_container(pod_request)
        pod_request_text = readable_pod_request(pod_request)

        if kc.debug:
            # keep the exact request next to the task attempt for later inspection
            logger.info("Pod Creation Request: \n%s", pod_request_text)
            pod_yaml = task_run.task_run_attempt_file("pod.yaml")
            pod_yaml.write(pod_request_text)
            logger.debug("Pod Request has been saved to %s", pod_yaml)

        links = self.build_external_links(pod)
        if links:
            task_run.set_external_resource_urls(links)

        task_run.set_task_run_state(TaskRunState.QUEUED)

        try:
            creation_response = self.kube_client.create_namespaced_pod(
                body=pod_request, namespace=pod.metadata.namespace)
            submitted_msg = "%s has been submitted at pod '%s' at namespace '%s'" % (
                task_run,
                pod.metadata.name,
                pod.metadata.namespace,
            )
            logger.info(submitted_msg)
            self.log.debug("Pod Creation Response: %s", creation_response)
        except ApiException as api_error:
            failure = TaskRunError.build_from_ex(api_error, task_run)
            task_run.set_task_run_state(TaskRunState.FAILED, error=failure)
            logger.error(
                "Exception when attempting to create Namespaced Pod using: %s",
                pod_request_text,
            )
            raise

        if not detach_run:
            self.wait()
        return self
Example #5
0
    def dbnd_set_task_pending_fail(self, pod_data, ex):
        """Mark the task run of a pod that is Pending with an exception as FAILED.

        Does nothing when no task run can be resolved from ``pod_data``. Otherwise a
        ``TaskRunError`` is built from ``ex``, the pod status is logged, the task run is
        set to FAILED and the status log is saved through the task run's tracker.

        :param pod_data: pod object; its ``metadata.name`` is used for logging.
        :param ex: the exception that describes why the pending pod failed.
        """
        metadata = pod_data.metadata

        task_run = _get_task_run_from_pod_data(pod_data)
        if not task_run:
            return
        from dbnd._core.task_run.task_run_error import TaskRunError

        # Fixed typo: was ``buid_from_ex``; ``build_from_ex`` is the factory used for
        # exceptions everywhere else in this code.
        task_run_error = TaskRunError.build_from_ex(ex, task_run)

        status_log = _get_status_log_safe(pod_data)
        logger.info(
            "Pod '%s' is Pending with exception, marking it as failed. Pod Status:\n%s",
            metadata.name,
            status_log,
        )
        task_run.set_task_run_state(TaskRunState.FAILED, error=task_run_error)
        task_run.tracker.save_task_run_log(status_log)
Example #6
0
    def _process_pod_success(self, submitted_pod):
        """Handle a 'pod succeeded' event for ``submitted_pod``.

        Pods already marked ``processed`` are skipped. Otherwise the Airflow task
        instance state decides the dbnd state:

        * ``SUCCESS`` -> ``TaskRunState.SUCCESS``
        * ``UP_FOR_RETRY`` / ``UP_FOR_RESCHEDULE`` -> ``TaskRunState.UP_FOR_RETRY``
        * ``FAILED`` / ``SHUTDOWN`` -> ``TaskRunState.FAILED``

        Any other state means the pod finished while the task instance did not record
        a final result. That case is handed to ``_handle_crashed_task_instance`` with
        ``PodFailureReason.err_pod_evicted``.
        """
        task_run = submitted_pod.task_run
        pod_name = submitted_pod.pod_name

        if submitted_pod.processed:
            # The format string has two placeholders; previously only ``pod_name`` was
            # passed, so logging failed to format the record and the message was lost.
            self.log.info(
                "%s Skipping pod 'success' event from %s: already processed",
                task_run,
                pod_name,
            )
            return
        ti = get_airflow_task_instance(task_run=task_run)

        if ti.state == State.SUCCESS:
            dbnd_state = TaskRunState.SUCCESS
        elif ti.state in {State.UP_FOR_RETRY, State.UP_FOR_RESCHEDULE}:
            dbnd_state = TaskRunState.UP_FOR_RETRY
        elif ti.state in {State.FAILED, State.SHUTDOWN}:
            dbnd_state = TaskRunState.FAILED
        else:
            # we got a corruption here:
            error_msg = (
                "Pod %s has finished with SUCCESS, but task instance state is %s, failing the job."
                % (pod_name, ti.state)
            )
            error_help = "Please check pod logs/eviction retry"
            task_run_error = TaskRunError.build_from_message(
                task_run, error_msg, help_msg=error_help
            )
            self._handle_crashed_task_instance(
                failure_reason=PodFailureReason.err_pod_evicted,
                task_run_error=task_run_error,
                task_run=task_run,
            )
            return

        # we print success message to the screen
        # we will not send it to databand tracking store (track=False)
        task_run.set_task_run_state(dbnd_state, track=False)
        self.log.info(
            "%s has been completed at pod '%s' with state %s try_number=%s!",
            task_run,
            pod_name,
            ti.state,
            ti._try_number,
        )
Example #7
0
    def execute(self,
                airflow_context=None,
                allow_resubmit=True,
                handle_sigterm=True):
        """Run this task (or resubmit it through its task engine) and track its state.

        Resubmit path: when ``allow_resubmit`` is set and the task engine's
        ``_should_wrap_with_submit_task`` returns true, the task is not executed in this
        process. Instead a system ``task_submit`` task is created that invokes
        ``<dbnd_executable> execute --dbnd-run <driver_dump> task_execute --task-id <id>``,
        and it is run as a dynamic task.

        Local path: the task runs inside ``task_run_execution_context``. Its state moves
        to RUNNING, then to SUCCESS. Failures set the state and raise again:

        * SIGTERM (``DatabandSigTermError``) and ``KeyboardInterrupt`` -> CANCELLED,
          re-raised.
        * ``SystemExit`` -> CANCELLED, raised as a friendly task-execution error.
        * Any other ``Exception`` -> FAILED, re-raised.

        :param airflow_context: stored on the task run; reset in the local path's finally.
        :param allow_resubmit: allow wrapping the task with an engine submit task.
        :param handle_sigterm: forwarded to ``task_run_execution_context``.
        :return: the result of ``task._task_submit()``; None when resubmitted or reused.
        """
        self.task_run.airflow_context = airflow_context
        task_run = self.task_run
        run = task_run.run
        run_config = run.run_config
        task = self.task  # type: Task
        task_engine = task_run.task_engine
        # Resubmit path: hand the task to the engine as a separate "execute ...
        # task_execute" invocation instead of running it in this process.
        # NOTE(review): this branch returns before the ``finally`` below, so the
        # ``airflow_context`` stored above is never reset here — confirm intended.
        if allow_resubmit and task_engine._should_wrap_with_submit_task(
                task_run):
            args = task_engine.dbnd_executable + [
                "execute",
                "--dbnd-run",
                str(run.driver_dump),
                "task_execute",
                "--task-id",
                task_run.task.task_id,
            ]
            submit_task = self.task_run.task_engine.submit_to_engine_task(
                env=task.task_env,
                task_name=SystemTaskName.task_submit,
                args=args)
            submit_task.task_meta.add_child(task.task_id)
            if run_config.open_web_tracker_in_browser:
                webbrowser.open_new_tab(task_run.task_tracker_url)
            run.run_dynamic_task(submit_task)
            return

        with self.task_run_execution_context(handle_sigterm=handle_sigterm):
            # refuse to start if the whole run has already been killed
            if run.is_killed():
                raise friendly_error.task_execution.databand_context_killed(
                    "task.execute_start of %s" % task)
            try:
                self.task_env.prepare_env()
                # reuse the existing outputs instead of running again, if configured
                if run_config.skip_completed_on_run and task._complete():
                    task_run.set_task_reused()
                    return
                task_run.set_task_run_state(state=TaskRunState.RUNNING)

                # presumably raises a descriptive error for the missing inputs
                if not self.task.ctrl.should_run():
                    self.task.ctrl.validator.find_and_raise_missing_inputs()

                if run_config.validate_task_inputs:
                    self.ctrl.validator.validate_task_inputs()

                try:
                    result = self.task._task_submit()
                    self.ctrl.save_task_band()
                    if run_config.validate_task_outputs:
                        self.ctrl.validator.validate_task_is_complete()
                finally:
                    # record the finish time whether or not the task body succeeded
                    self.task_run.finished_time = utcnow()

                task_run.set_task_run_state(TaskRunState.SUCCESS)
                run.cleanup_after_task_run(task)

                return result
            except DatabandSigTermError as ex:
                logger.error(
                    "Sig TERM! Killing the task '%s' via task.on_kill()",
                    task_run.task.task_id,
                )
                # SIGTERM: mark the run killed first, then ask the task to stop
                run._internal_kill()

                error = TaskRunError.build_from_ex(ex, task_run)
                try:
                    task.on_kill()
                except Exception:
                    # NOTE(review): message mentions "keyboard interrupt" although this
                    # is the SIGTERM handler — looks copied from the handler below.
                    logger.exception(
                        "Failed to kill task on user keyboard interrupt")
                task_run.set_task_run_state(TaskRunState.CANCELLED,
                                            error=error)
                raise
            except KeyboardInterrupt as ex:
                logger.error("User Interrupt! Killing the task %s",
                             task_run.task.task_id)
                error = TaskRunError.build_from_ex(ex, task_run)
                try:
                    # optionally ask the user to confirm before calling on_kill()
                    if task._conf_confirm_on_kill_msg:
                        from dbnd._vendor import click

                        if click.confirm(task._conf_confirm_on_kill_msg,
                                         default=True):
                            task.on_kill()
                        else:
                            logger.warning(
                                "Task is not killed accordingly to user input!"
                            )
                    else:
                        task.on_kill()
                except Exception:
                    logger.exception(
                        "Failed to kill task on user keyboard interrupt")
                # unlike the SIGTERM path, the run is marked killed after the state update
                task_run.set_task_run_state(TaskRunState.CANCELLED,
                                            error=error)
                run._internal_kill()
                raise
            except SystemExit as ex:
                # SystemExit is converted into a friendly task-execution error
                error = TaskRunError.build_from_ex(ex, task_run)
                task_run.set_task_run_state(TaskRunState.CANCELLED,
                                            error=error)
                raise friendly_error.task_execution.system_exit_at_task_run(
                    task, ex)
            except Exception as ex:
                error = TaskRunError.build_from_ex(ex, task_run)
                task_run.set_task_run_state(TaskRunState.FAILED, error=error)
                # presumably marks the error as already displayed to avoid duplicates
                show_error_once.set_shown(ex)
                raise
            finally:
                # clear the airflow context stored at the start of this method
                task_run.airflow_context = None
Example #8
0
    def run_pod(self, task_run, pod, detach_run=False):
        # type: (TaskRun, Pod, bool) -> DbndPodCtrl
        """Create ``pod`` in Kubernetes for ``task_run`` and, unless detached, wait on it.

        ``detach_run`` (or ``kube_config.detach_run``) requests returning right after
        submission. ``show_pod_log`` and ``debug`` force blocking mode.

        Steps:

        * In debug mode the readable pod request is dumped to the task attempt's
          ``pod.yaml``.
        * Dashboard / pod-log links are attached to the task run.
        * The task run is moved to QUEUED and the pod is created.
        * On ``ApiException`` the task run is marked FAILED and the error is re-raised.

        :return: ``self``
        """
        kc = self.kube_config

        detach_run = detach_run or kc.detach_run
        # showing pod logs and debugging both require blocking (non-detached) mode
        for flag_name, flag_enabled in (
            ("show_pod_logs", kc.show_pod_log),
            ("debug", kc.debug),
        ):
            if flag_enabled:
                logger.info(
                    "%s is True, %s will send every docker in blocking mode",
                    flag_name,
                    kc.task_name,
                )
                detach_run = False

        req = kc.build_kube_pod_req(pod)
        readable_req_str = readable_pod_request(req)

        if kc.debug:
            logger.info("Pod Creation Request: \n%s", readable_req_str)
            pod_file = task_run.task_run_attempt_file("pod.yaml")
            pod_file.write(readable_req_str)
            logger.debug("Pod Request has been saved to %s", pod_file)

        dashboard_url = kc.get_dashboard_link(pod)
        pod_log = kc.get_pod_log_link(pod)
        external_link_dict = {}
        if dashboard_url:
            external_link_dict["k8s_dashboard"] = dashboard_url
        if pod_log:
            external_link_dict["pod_log"] = pod_log
        if external_link_dict:
            task_run.set_external_resource_urls(external_link_dict)
        task_run.set_task_run_state(TaskRunState.QUEUED)

        try:
            resp = self.kube_client.create_namespaced_pod(
                body=req, namespace=pod.namespace
            )
            logger.info("Started pod '%s' in namespace '%s'", pod.name, pod.namespace)
            logger.debug("Pod Creation Response: %s", resp)
        except ApiException as ex:
            task_run_error = TaskRunError.build_from_ex(ex, task_run)
            task_run.set_task_run_state(TaskRunState.FAILED, error=task_run_error)
            logger.error(
                "Exception when attempting to create Namespaced Pod using: %s",
                readable_req_str,
            )
            raise
        # Use the module logger like the rest of this method; this previously went to
        # the root ``logging`` module, bypassing the module logger's configuration.
        logger.debug("Kubernetes Job created!")

        # TODO this is pretty dirty.
        #  Better to extract the deploy error checking logic out of the pod launcher and have the watcher
        #   pass an exception through the watcher queue if needed. Current airflow implementation doesn't implement that, so we will stick with the current flow

        if not detach_run:
            self.wait()
        return self
Example #9
0
    def dbnd_set_task_failed(self, pod_data):
        """Handle a failed pod: collect its logs/status and update the task run.

        The event is skipped when no task run can be resolved from ``pod_data``, or when
        the task run is already FAILED. Otherwise:

        * The pod logs are collected.
        * An error message is built from the pod name/namespace, the first 15 collected
          log lines and the pod status.
        * The Airflow task-instance state then decides what happens:

          - FAILED: the error is only recorded locally (``track=False``) and a red
            banner is logged.
          - QUEUED: ``handle_pod_retry(..., increment_try_number=True)``; if no retry
            was sent, the task run is set FAILED (tracked).
          - RUNNING: ``handle_pod_retry(...)``; if no retry was sent, the task run is
            set FAILED (tracked).
          - any other state: set FAILED (tracked).

        In every non-FAILED case the collected logs, if any, are saved via the tracker.
        """
        metadata = pod_data.metadata
        logger.debug("Getting task run")
        task_run = _get_task_run_from_pod_data(pod_data)
        if not task_run:
            logger.info("Can't find a task run for %s", metadata.name)
            return
        # already failed -> ignore repeated failure events for the same pod
        if task_run.task_run_state == TaskRunState.FAILED:
            logger.info("Skipping 'failure' event from %s", metadata.name)
            return

        pod_ctrl = self.get_pod_ctrl(metadata.name, metadata.namespace)
        logs = []
        # noinspection PyBroadException
        try:
            log_printer = lambda x: logs.append(x)
            # NOTE(review): the pod log is streamed twice into the same list — first the
            # last 100 lines, then the full log — so ``logs`` holds the tail followed by
            # the whole log. Looks like a leftover; confirm which call is intended.
            pod_ctrl.stream_pod_logs(
                print_func=log_printer, tail_lines=100, follow=False
            )
            pod_ctrl.stream_pod_logs(print_func=log_printer, follow=False)
        except Exception as ex:
            # when deleting pods we get extra failure events so we will have lots of this in the log
            if isinstance(ex, ApiException) and ex.status == 404:
                logger.info(
                    "failed to get log for pod %s: pod not found", metadata.name
                )
            else:
                logger.error("failed to get log for %s: %s", metadata.name, ex)

        try:
            short_log = "\n".join(["out:%s" % l for l in logs[:15]])
        except Exception as ex:
            logger.error(
                "failed to build short log message for %s: %s", metadata.name, ex
            )
            short_log = None

        status_log = _get_status_log_safe(pod_data)

        from dbnd._core.task_run.task_run_error import TaskRunError

        # Work around to build an error object: raise and immediately catch a
        # DatabandError so that TaskRunError.build_from_ex can be used.
        try:
            err_msg = "Pod %s at %s has failed!" % (metadata.name, metadata.namespace)
            if short_log:
                err_msg += "\nLog:%s" % short_log
            if status_log:
                err_msg += "\nPod Status:%s" % status_log
            raise DatabandError(
                err_msg,
                show_exc_info=False,
                help_msg="Please see full pod log for more details",
            )
        except DatabandError as ex:
            error = TaskRunError.build_from_ex(ex, task_run)

        airflow_task_state = get_airflow_task_instance_state(task_run=task_run)
        logger.debug("task airflow state: %s ", airflow_task_state)
        from airflow.utils.state import State

        if airflow_task_state == State.FAILED:
            # let just notify the error, so we can show it in summary it
            # we will not send it to databand tracking store
            task_run.set_task_run_state(TaskRunState.FAILED, track=False, error=error)
            logger.info(
                "%s",
                task_run.task.ctrl.banner(
                    "Task %s has failed at pod '%s'!"
                    % (task_run.task.task_name, metadata.name),
                    color="red",
                    task_run=task_run,
                ),
            )
        else:
            if airflow_task_state == State.QUEUED:
                # Special case - no airflow code has been run in the pod at all. Must increment try number and send
                # to retry if exit code is matching
                if not pod_ctrl.handle_pod_retry(
                    pod_data, task_run, increment_try_number=True
                ):
                    # No retry was sent
                    task_run.set_task_run_state(
                        TaskRunState.FAILED, track=True, error=error
                    )
            elif airflow_task_state == State.RUNNING:
                # Task was killed unexpectedly -- probably pod failure in K8s - Possible retry attempt
                if not pod_ctrl.handle_pod_retry(pod_data, task_run):
                    # No retry was sent
                    task_run.set_task_run_state(
                        TaskRunState.FAILED, track=True, error=error
                    )
            else:
                task_run.set_task_run_state(
                    TaskRunState.FAILED, track=True, error=error
                )
            # logs are uploaded only in this (non-FAILED airflow state) branch
            if logs:
                task_run.tracker.save_task_run_log("\n".join(logs))
Example #10
0
    def tracking_context(self, call_args, call_kwargs):
        """Track one call of the wrapped function as a dynamic dbnd task run.

        This is a generator with a single effective ``yield``. Per the in-code comments it
        is used as a context manager, and the caller's code runs at that ``yield``.

        Tracked path (taken when the call-count limit is not exceeded and
        ``_get_or_create_inplace_task()`` is truthy):

        * A task and task run are created and wired into the DAG, and the task run is
          set RUNNING.
        * ``func_call.set_result`` is yielded; presumably the caller reports the
          function's return value through it.
        * Afterwards the task run is set SUCCESS, or FAILED if user code raised.

        Passthrough path: when tracking is not set up (limit exceeded, no inplace task,
        or a dbnd error before user code started), ``_passthrough_decorator`` is
        yielded instead.

        Exceptions from user code are always re-raised. Exceptions raised by dbnd
        tracking itself are swallowed; they are reported through
        ``_handle_dynamic_error`` only when ``func_call`` was already created.
        """
        user_code_called = False  # whether we got to executing of user code
        user_code_finished = False  # whether we passed executing of user code
        func_call = None
        try:
            func_call = FuncCallWithResult(
                task_cls=self.get_tracking_task_cls(),
                call_user_code=self.func,
                call_args=tuple(
                    call_args),  # prevent original call_args modification
                call_kwargs=dict(
                    call_kwargs),  # prevent original kwargs modification
            )

            # 1. check that we don't have too many calls
            # 2. Start or reuse existing "inplace_task" that is root for tracked tasks
            if not self._call_count_limit_exceeded(
            ) and _get_or_create_inplace_task():
                cls = func_call.task_cls

                # replace any positional argument with a kwarg where possible
                args, kwargs = args_to_kwargs(
                    cls._conf__decorator_spec.args,
                    func_call.call_args,
                    func_call.call_kwargs,
                )

                # instantiate inline task
                task = cls._create_task(args, kwargs)

                # update upstream/downstream relations - needed for correct tracking
                # we can have the task as upstream, as it was executed already
                parent_task = current_task_run().task
                if not parent_task.task_dag.has_upstream(task):
                    parent_task.set_upstream(task)

                # checking if any of the inputs are the outputs of previous task.
                # we can add that task as upstream.
                dbnd_run = get_databand_run()
                call_kwargs_as_targets = dbnd_run.target_origin.get_for_map(
                    kwargs)
                for value_origin in call_kwargs_as_targets.values():
                    up_task = value_origin.origin_target.task
                    task.set_upstream(up_task)

                # creating task_run as a task we found mid-run
                task_run = dbnd_run.create_dynamic_task_run(
                    task, task_engine=current_task_run().task_engine)

                with task_run.runner.task_run_execution_context(
                        handle_sigterm=True):
                    task_run.set_task_run_state(state=TaskRunState.RUNNING)

                    _log_inputs(task_run)

                    # if we reached this line, then all tracking initialization is
                    # finished successfully, and we're going to execute user code
                    user_code_called = True

                    try:
                        # tracking_context is context manager - user code will run on yield
                        yield func_call.set_result

                        # if we reached this line, this means that user code finished
                        # successfully without any exceptions
                        user_code_finished = True
                    except Exception as ex:
                        task_run.finished_time = utcnow()

                        error = TaskRunError.build_from_ex(ex, task_run)
                        task_run.set_task_run_state(TaskRunState.FAILED,
                                                    error=error)
                        raise
                    else:
                        task_run.finished_time = utcnow()

                        # func_call.result should contain result, log it
                        _log_result(task_run, func_call.result)

                        task_run.set_task_run_state(TaskRunState.SUCCESS)
        except Exception:
            if user_code_called and not user_code_finished:
                # if we started to call the user code and not got to user_code_finished
                # line - it means there was user code exception - so just re-raise it
                raise
            # else it's either we didn't reached calling user code, or already passed it
            # then it's some dbnd tracking error - just log it
            # NOTE(review): errors after user code finished (e.g. in _log_result) are
            # also reported as "tracking-init"; errors raised before func_call exists
            # are swallowed without any report.
            if func_call:
                _handle_dynamic_error("tracking-init", func_call)
        # if we didn't reached user_code_called=True line - there was an error during
        # dbnd tracking initialization, so nothing is done - user function wasn't called yet
        if not user_code_called:
            # tracking_context is context manager - user code will run on yield
            yield _passthrough_decorator
Example #11
0
    def _process_pod_failed(
        self,
        submitted_pod: SubmittedPodState,
        known_fail_reason: Optional[PodFailureReason] = None,
    ):
        """Handle a 'pod failed' event for ``submitted_pod``.

        Collects the pod status and, unless the pod is still Pending, its logs. It then
        builds a ``TaskRunError`` describing the failure, including a failure reason
        discovered from the pod data when available, and:

        * If the Airflow task instance is already finished, it only records the mapped
          dbnd state locally (``track=False``) and logs a coloured banner.
        * Otherwise it uploads a task-run log (error message, pod status, full pod
          logs) and delegates to ``_handle_crashed_task_instance``.

        :param known_fail_reason: failure reason already known to the caller; takes
            precedence over the reason discovered from the pod data.
        """
        task_run = submitted_pod.task_run
        pod_name = submitted_pod.pod_name

        task_id = task_run.task_af_id
        ti_state = get_airflow_task_instance_state(task_run=task_run)

        self.log.error(
            "%s: pod %s has crashed, airflow state: %s", task_run, pod_name, ti_state
        )

        pod_data = self.kube_dbnd.get_pod_status(pod_name)
        pod_ctrl = self.kube_dbnd.get_pod_ctrl(pod_name, self.namespace)

        pod_logs = []
        if pod_data:
            pod_status_log = _get_status_log_safe(pod_data)
            # logs are only fetched once the pod has left the Pending phase
            if pod_data.status.phase != "Pending":
                pod_logs = pod_ctrl.get_pod_logs()
        else:
            pod_status_log = "POD NOT FOUND"

        error_msg = "Pod %s at %s has failed (task state=%s)!" % (
            pod_name,
            self.namespace,
            ti_state,
        )
        failure_reason, failure_message = self._find_pod_failure_reason(
            pod_data=pod_data, pod_name=pod_name
        )
        if failure_reason:
            error_msg += " Discovered reason for failure is %s: %s." % (
                failure_reason,
                failure_message,
            )
        error_help_msg = "Please see full pod log for more details."
        if pod_logs:
            # only the last 20 log lines go into the help message
            error_help_msg += "\nPod logs:\n%s\n" % "\n".join(
                "out: %s" % line for line in pod_logs[-20:]
            )

        from dbnd._core.task_run.task_run_error import TaskRunError

        task_run_error = TaskRunError.build_from_message(
            task_run=task_run, msg=error_msg, help_msg=error_help_msg
        )

        if is_task_instance_finished(ti_state):
            # Pod has failed, however, Airflow managed to update the state
            # that means - all code (including dbnd) were executed
            # let just notify the error, so we can show it in the summary
            # we will not send it to databand tracking store, only print to console
            dbnd_state = AIRFLOW_TO_DBND_STATE_MAP.get(ti_state)
            task_run.set_task_run_state(
                dbnd_state, track=False, error=task_run_error
            )

            if dbnd_state == TaskRunState.FAILED:
                color = "red"
            elif dbnd_state == TaskRunState.SUCCESS:
                color = "cyan"
            else:
                color = "yellow"

            self.log.info(
                "%s",
                task_run.task.ctrl.banner(
                    "Task %s(%s) - pod %s has failed, airflow state=%s!"
                    % (task_run.task.task_name, task_id, pod_name, ti_state),
                    color=color,
                    task_run=task_run,
                ),
            )
            return

        # we got State.Failed from watcher, but at DB airflow instance in unfinished state
        # that means the task has failed in the middle
        # (all kind of errors and exit codes)
        # Keep the error message and the pod status on separate lines; they were
        # previously concatenated with no delimiter.
        task_run_log = "%s\n%s" % (error_msg, pod_status_log)
        if pod_logs:
            # let's upload its logs - we don't know what happened
            task_run_log += "\nPod logs:\n\n%s\n\n" % "\n".join(pod_logs)
        task_run.tracker.save_task_run_log(task_run_log)

        self._handle_crashed_task_instance(
            task_run=task_run,
            task_run_error=task_run_error,
            failure_reason=known_fail_reason or failure_reason,
        )
Example #12
0
    def _process_pod_failed(self, submitted_pod):
        """Handle a 'pod failed' event for ``submitted_pod``.

        Collects the pod status and, unless the pod is still Pending, its logs. It then
        builds a ``TaskRunError`` describing the failure, including a failure reason
        discovered from the pod data when available, and:

        * If the Airflow task instance is FAILED, it only records the failure locally
          (``track=False``), logs a red banner and returns True.
        * Otherwise it uploads a task-run log (error message, pod status, full pod
          logs), delegates to ``_handle_crashed_task_instance`` and returns None.
        """
        task_run = submitted_pod.task_run
        pod_name = submitted_pod.pod_name

        task_id = task_run.task_af_id
        ti_state = get_airflow_task_instance_state(task_run=task_run)

        self.log.info(
            "%s: pod %s has crashed, airflow state: %s", task_run, pod_name, ti_state
        )

        pod_data = self.get_pod_status(pod_name)
        pod_ctrl = self.kube_dbnd.get_pod_ctrl(pod_name, self.namespace)

        pod_logs = []
        if pod_data:
            pod_status_log = _get_status_log_safe(pod_data)
            # logs are only fetched once the pod has left the Pending phase
            if pod_data.status.phase != "Pending":
                pod_logs = pod_ctrl.get_pod_logs()
        else:
            pod_status_log = "POD NOT FOUND"

        error_msg = "Pod %s at %s has failed (task state=%s)!" % (
            pod_name,
            self.namespace,
            ti_state,
        )
        failure_reason, failure_message = self._find_pod_failure_reason(
            task_run=task_run, pod_data=pod_data, pod_name=pod_name
        )
        if failure_reason:
            # leading space: error_msg ends with "!" and the sentences ran together
            error_msg += " Found reason for failure: %s - %s." % (
                failure_reason,
                failure_message,
            )
        error_help_msg = "Please see full pod log for more details."
        if pod_logs:
            # only the last 20 log lines go into the help message
            error_help_msg += "\nPod logs:\n%s\n" % "\n".join(
                "out: %s" % line for line in pod_logs[-20:]
            )

        from dbnd._core.task_run.task_run_error import TaskRunError

        task_run_error = TaskRunError.build_from_message(
            task_run=task_run, msg=error_msg, help_msg=error_help_msg,
        )

        if ti_state == State.FAILED:
            # Pod has failed, however, Airflow managed to update the state
            # that means - all code (including dbnd) were executed
            # let just notify the error, so we can show it in summary it
            # we will not send it to databand tracking store
            task_run.set_task_run_state(
                TaskRunState.FAILED, track=False, error=task_run_error
            )
            self.log.info(
                "%s",
                task_run.task.ctrl.banner(
                    "Task %s(%s) - pod %s has failed, airflow state=Failed!"
                    % (task_run.task.task_name, task_id, pod_name),
                    color="red",
                    task_run=task_run,
                ),
            )
            return True
        # we got State.Failed from watcher, but at DB airflow instance in different state
        # that means the task has failed in the middle
        # (all kind of errors and exit codes)
        # Keep the error message and the pod status on separate lines; they were
        # previously concatenated with no delimiter.
        task_run_log = "%s\n%s" % (error_msg, pod_status_log)
        if pod_logs:
            # let's upload its logs - we don't know what happened
            task_run_log += "\nPod logs:\n\n%s\n\n" % "\n".join(pod_logs)
        task_run.tracker.save_task_run_log(task_run_log)

        self._handle_crashed_task_instance(
            task_run=task_run,
            task_run_error=task_run_error,
            failure_reason=failure_reason,
        )
Example #13
0
def new_execute(context):
    """
    This function replaces the operator's original `execute` function.

    When the DAG is eligible for tracking, the original ``execute`` is wrapped
    with dbnd tracking: tracking start, failure reporting, xcom/result logging
    and tracking stop. Errors raised by tracking start, failure reporting and
    xcom/result logging are logged instead of raised. The operator's own
    exceptions are always re-raised unchanged.

    :param context: the airflow execution context passed to ``execute``.
    :return: whatever the operator's original ``execute`` returned.
    """
    # IMPORTANT!!: copied_operator:
    # ---------------------------------------
    # The task (=operator) is copied when airflow enters to TaskInstance._run_raw_task.
    # Then, only the copy_task (=copy_operator) is changed or called (render jinja, signal_handler,
    # pre_execute, execute, etc..).
    copied_operator = context["task_instance"].task

    if not is_dag_eligable_for_tracking(context["task_instance"].dag_id):
        execute = get_execute_function(copied_operator)
        result = execute(copied_operator, context)
        return result

    try:
        # Set that we are in Airflow tracking mode
        get_dbnd_project_config().set_is_airflow_runtime()

        task_context = extract_airflow_context(context)
        # start operator execute run with current airflow context
        task_run = dbnd_airflow_tracking_start(
            airflow_context=task_context)  # type: Optional[TaskRun]

    except Exception as e:
        task_run = None
        logger.error(
            "exception caught while running on dbnd new execute {}".format(e),
            exc_info=True,
        )

    from airflow.exceptions import AirflowRescheduleException

    def _stop_tracking_quietly(**stop_kwargs):
        # Stop dbnd tracking while an operator exception is propagating.
        # A failure here is logged, not raised, so it cannot replace that
        # exception. This lives in its own function so that the caller's bare
        # `raise` still re-raises the operator's exception, including on
        # Python 2, where a nested handled exception would otherwise win.
        try:
            dbnd_tracking_stop(**stop_kwargs)
        except Exception as stop_ex:
            logger.error(
                "exception caught while stopping dbnd tracking {}".format(stop_ex),
                exc_info=True,
            )

    def _report_failure_and_stop(ex):
        # Mark the task run as FAILED with `ex`, then stop tracking. dbnd
        # errors are logged, not raised, for the same reason as above.
        if task_run:
            try:
                error = TaskRunError.build_from_ex(ex, task_run)
                task_run.set_task_run_state(state=TaskRunState.FAILED, error=error)
            except Exception as report_ex:
                logger.error(
                    "exception caught while reporting airflow operator failure {}".format(
                        report_ex
                    ),
                    exc_info=True,
                )
        _stop_tracking_quietly()

    # running the operator's original execute function
    try:
        with af_tracking_context(task_run, context, copied_operator):
            execute = get_execute_function(copied_operator)
            result = execute(copied_operator, context)

    # Check if this is sensor task that is retrying - normal behavior and not really an exception
    except AirflowRescheduleException:
        _stop_tracking_quietly(finalize_run=False)
        raise
    # catch if the original execute failed
    except Exception as ex:
        _report_failure_and_stop(ex)
        raise

    # if we have a task run here we want to log results and xcoms
    if task_run:
        try:
            track_config = AirflowTrackingConfig.from_databand_context()
            if track_config.track_xcom_values:
                # reporting xcoms as metrix of the task
                log_xcom(context, track_config)

            if track_config.track_airflow_execute_result:
                # reporting the result
                log_operator_result(task_run, result, copied_operator,
                                    track_config.track_xcom_values)

        except Exception as e:
            logger.error(
                "exception caught will tracking airflow operator {}".format(e),
                exc_info=True,
            )

    # make sure we close and return the original results
    dbnd_tracking_stop()
    return result
Example #14
0
    def tracking_context(self, call_args, call_kwargs):
        """Track one call of ``self.task_func`` as a dynamic dbnd task run.

        Generator used as a context manager. The caller receives
        ``result_tracker`` and should pass the function's result through it so
        the result can be logged on success.

        If user code raises at the yield (including ``KeyboardInterrupt`` and
        other ``BaseException`` subclasses), the task run is marked FAILED and
        the exception is re-raised. Errors from dbnd tracking itself (``Exception``
        subclasses) are passed to ``_handle_dynamic_error`` and not re-raised.
        When tracking could not be initialized, ``result_tracker`` is still
        yielded, so the user code runs untracked.

        :param call_args: positional arguments of the tracked call (copied).
        :param call_kwargs: keyword arguments of the tracked call (copied).
        """
        func_call = FuncCall(
            task_cls=self.task_cls,
            call_user_code=self.task_func,
            call_args=tuple(
                call_args),  # prevent original call_args modification
            call_kwargs=dict(
                call_kwargs),  # prevent original call_kwargs modification
        )

        # this will be returned by context manager to store function result
        # for later tracking
        def result_tracker(result):
            func_call.result = result
            return result

        user_code_called = False  # whether we got to executing of user code
        user_code_finished = False  # whether we passed executing of user code
        try:
            if not self._call_count_limit_exceeded(
            ) and _get_or_create_inplace_task():
                task_run = _create_dynamic_task_run(func_call)
                with task_run.runner.task_run_execution_context(
                        handle_sigterm=True):
                    task_run.set_task_run_state(state=TaskRunState.RUNNING)

                    _log_inputs(task_run)

                    # if we reached this line, then all tracking initialization is
                    # finished successfully, and we're going to execute user code
                    user_code_called = True

                    try:
                        # tracking_context is context manager - user code will run on yield
                        yield result_tracker

                        # if we reached this line, this means that user code finished
                        # successfully without any exceptions
                        user_code_finished = True
                    # BaseException, so that KeyboardInterrupt/SystemExit raised by
                    # user code also mark the task run as FAILED before propagating
                    except BaseException as ex:
                        task_run.finished_time = utcnow()

                        error = TaskRunError.build_from_ex(ex, task_run)
                        task_run.set_task_run_state(TaskRunState.FAILED,
                                                    error=error)
                        raise
                    else:
                        task_run.finished_time = utcnow()

                        # func_call.result should contain result, log it
                        _log_result(task_run, func_call.result)

                        task_run.set_task_run_state(TaskRunState.SUCCESS)
        except Exception:
            if user_code_called and not user_code_finished:
                # if we started to call the user code and not got to user_code_finished
                # line - it means there was user code exception - so just re-raise it
                raise
            # else it's either we didn't reached calling user code, or already passed it
            # then it's some dbnd tracking error - just log it
            _handle_dynamic_error("tracking-init", func_call)

        # if we didn't reached user_code_called=True line - there was an error during
        # dbnd tracking initialization, so nothing is done - user function wasn't called yet
        if not user_code_called:
            # tracking_context is context manager - user code will run on yield
            yield result_tracker
Example #15
0
    def tracking_context(self, call_args, call_kwargs):
        """Track one call of ``self.callable`` as a dbnd tracking task run.

        Generator used as a context manager that yields exactly once:
        - ``func_call.set_result`` when tracking is set up, so the caller can
          hand back the function's result, which is logged on success;
        - ``_do_nothing_decorator`` when tracking is skipped (call-count limit
          exceeded, no inplace tracking task) or its initialization failed.

        Exceptions raised by user code at the yield are recorded on the task
        run (when one exists) and re-raised unchanged. Errors raised by dbnd
        tracking itself go to ``_handle_tracking_error`` /
        ``log_exception_to_server`` and are not re-raised.

        :param call_args: positional arguments of the tracked call (copied).
        :param call_kwargs: keyword arguments of the tracked call (copied).
        """
        user_code_called = False  # whether we got to executing of user code
        user_code_finished = False  # whether we passed executing of user code
        func_call = None
        try:
            # 1. check that we don't have too many calls
            if self._call_count_limit_exceeded():
                # The yield below runs user code, just untracked. Mark it as
                # called so that a user exception is re-raised by the handler
                # below. Otherwise it would be treated as a tracking error and
                # the generator would yield a second time, which makes the
                # context manager fail with "generator didn't stop after throw()".
                user_code_called = True
                yield _do_nothing_decorator
                return

            # 2. Start or reuse existing "main tracking task" that is root for tracked tasks
            if not try_get_current_task():
                # try to get existing task, and if not exists - try to get/create inplace_task_run
                from dbnd._core.tracking.script_tracking_manager import (
                    try_get_inplace_tracking_task_run, )

                inplace_tacking_task = try_get_inplace_tracking_task_run()
                if not inplace_tacking_task:
                    # we didn't manage to start inplace tracking task run, we will not be able to track.
                    # User code still runs on this yield - mark it as called (see above).
                    user_code_called = True
                    yield _do_nothing_decorator
                    return

            tracking_task_definition = self.get_tracking_task_definition()
            callable_spec = tracking_task_definition.task_decorator.get_callable_spec(
            )

            func_call = TrackedFuncCallWithResult(
                callable=self.callable,
                call_args=tuple(
                    call_args),  # prevent original call_args modification
                call_kwargs=dict(
                    call_kwargs),  # prevent original kwargs modification
            )
            # replace any position argument with kwarg if it possible
            args, kwargs = args_to_kwargs(callable_spec.args,
                                          func_call.call_args,
                                          func_call.call_kwargs)

            # instantiate inline task
            task = TrackingTask.for_func(tracking_task_definition, args,
                                         kwargs)

            # update upstream/downstream relations - needed for correct tracking
            # we can have the task as upstream , as it was executed already
            parent_task = current_task_run().task
            if not parent_task.task_dag.has_upstream(task):
                parent_task.set_upstream(task)

            # checking if any of the inputs are the outputs of previous task.
            # we can add that task as upstream.
            dbnd_run = get_databand_run()
            call_kwargs_as_targets = dbnd_run.target_origin.get_for_map(kwargs)
            for value_origin in call_kwargs_as_targets.values():
                up_task = value_origin.origin_target.task
                task.set_upstream(up_task)

            # creating task_run as a task we found mid-run
            task_run = dbnd_run.create_task_run_at_execution_time(
                task, task_engine=current_task_run().task_engine)

            should_capture_log = (
                TrackingConfig.from_databand_context().capture_tracking_log)
            with task_run.runner.task_run_execution_context(
                    handle_sigterm=True, capture_log=should_capture_log):
                task_run.set_task_run_state(state=TaskRunState.RUNNING)

                _log_inputs(task_run)

                # if we reached this line, then all tracking initialization is
                # finished successfully, and we're going to execute user code
                user_code_called = True

                try:
                    # tracking_context is context manager - user code will run on yield
                    yield func_call.set_result

                    # if we reached this line, this means that user code finished
                    # successfully without any exceptions
                    user_code_finished = True
                # We catch BaseException since we want to catch KeyboardInterrupts as well
                except BaseException as ex:
                    task_run.finished_time = utcnow()

                    error = TaskRunError.build_from_ex(ex, task_run)
                    task_run.set_task_run_state(TaskRunState.FAILED,
                                                error=error)
                    raise

                else:
                    task_run.finished_time = utcnow()

                    # func_call.result should contain result, log it
                    _log_result(task_run, func_call.result)

                    task_run.set_task_run_state(TaskRunState.SUCCESS)
        except BaseException:
            if user_code_called and not user_code_finished:
                # if we started to call the user code and not got to user_code_finished
                # line - it means there was user code exception - so just re-raise it
                raise
            # else it's either we didn't reached calling user code, or already passed it
            # then it's some dbnd tracking error - just log it
            if func_call:
                _handle_tracking_error("tracking-init", func_call)
            else:
                log_exception_to_server()
        # if we didn't reached user_code_called=True line - there was an error during
        # dbnd tracking initialization, so nothing is done - user function wasn't called yet
        if not user_code_called:
            # tracking_context is context manager - user code will run on yield
            yield _do_nothing_decorator
            return
Example #16
0
def new_execute(context):
    """
    This function replaces the operator's original `execute` function.

    Starts a dbnd run for the current airflow context, lets the operator add
    its tracking information to the submitted task, runs the original
    ``execute``, logs xcoms and the result, and stops the dbnd run. Errors
    raised by tracking start, failure reporting and xcom/result logging are
    logged instead of raised. The operator's own exceptions are always
    re-raised unchanged.

    :param context: the airflow execution context passed to ``execute``.
    :return: whatever the operator's original ``execute`` returned.
    """
    # IMPORTANT!!: copied_operator:
    # ---------------------------------------
    # The task (=operator) is copied when airflow enters to TaskInstance._run_raw_task.
    # Then, only the copy_task (=copy_operator) is changed or called (render jinja, signal_handler,
    # pre_execute, execute, etc..).
    copied_operator = context["task_instance"].task

    try:
        # start operator execute run with current airflow context
        task_context = extract_airflow_context(context)
        task_run = dbnd_run_start(
            airflow_context=task_context
        )  # type: Optional[TaskRun]

        # custom manipulation for each operator
        if task_run:
            tracking_info = get_tracking_information(context, task_run)
            add_tracking_to_submit_task(tracking_info, copied_operator)

    except Exception as e:
        task_run = None
        logger.error(
            "exception caught will running on dbnd new execute {}".format(e),
            exc_info=True,
        )

    def _report_failure_and_stop(ex):
        # Mark the task run as FAILED with `ex` and stop the dbnd run while the
        # operator's exception is propagating. dbnd errors are logged, not
        # raised, so they cannot replace that exception. This lives in its own
        # function so that the caller's bare `raise` still re-raises the
        # operator's exception, including on Python 2.
        if task_run:
            try:
                error = TaskRunError.build_from_ex(ex, task_run)
                task_run.set_task_run_state(state=TaskRunState.FAILED, error=error)
            except Exception as report_ex:
                logger.error(
                    "exception caught while reporting airflow operator failure {}".format(
                        report_ex
                    ),
                    exc_info=True,
                )
        try:
            dbnd_run_stop()
        except Exception as stop_ex:
            logger.error(
                "exception caught while stopping dbnd run {}".format(stop_ex),
                exc_info=True,
            )

    # running the operator's original execute function
    try:
        execute = get_execute_function(copied_operator)
        result = execute(copied_operator, context)

    # catch if the original execute failed
    except Exception as ex:
        _report_failure_and_stop(ex)
        raise

    # if we have a task run here we want to log results and xcoms
    if task_run:
        try:
            track_config = AirflowTrackingConfig.current()
            if track_config.track_xcom_values:
                log_xcom(context, track_config)

            if track_config.track_airflow_execute_result:
                log_operator_result(
                    task_run, result, copied_operator, track_config.track_xcom_values
                )

        except Exception as e:
            logger.error(
                "exception caught will tracking airflow operator {}".format(e),
                exc_info=True,
            )

    # make sure we close and return the original results
    dbnd_run_stop()
    return result
Example #17
0
    def run_pod(self, task_run, pod, detach_run=False):
        # type: (TaskRun, Pod, bool) -> DbndPodCtrl
        """Submit `pod` to kubernetes on behalf of `task_run`.

        The run is detached (returns right after submission) when `detach_run`
        or the kube config's ``detach_run`` is set. Enabling ``show_pod_log`` or
        ``debug`` in the kube config forces a blocking run that waits via
        ``self.wait()``. ``self`` is returned in both cases.

        Dashboard/pod-log links (when available) are attached to the task run,
        and the task run is set to QUEUED before submission. If the kubernetes
        API raises ``ApiException``, the task run is marked FAILED and the
        exception is re-raised.
        """
        kube_conf = self.kube_config

        detach_run = detach_run or kube_conf.detach_run
        # Each of these kube config flags forces a blocking (attached) run.
        for setting_label, setting_attr in (
            ("show_pod_logs", "show_pod_log"),
            ("debug", "debug"),
        ):
            if getattr(kube_conf, setting_attr):
                logger.info(
                    "%s is True, %s will send every docker in blocking mode",
                    setting_label,
                    kube_conf.task_name,
                )
                detach_run = False

        pod_request = kube_conf.build_kube_pod_req(pod)
        pod_request_text = readable_pod_request(pod_request)

        if kube_conf.debug:
            # keep a readable copy of the request as a task run attempt file
            logger.info("Pod Creation Request: \n%s", pod_request_text)
            request_file = task_run.task_run_attempt_file("pod.yaml")
            request_file.write(pod_request_text)
            logger.debug("Pod Request has been saved to %s", request_file)

        dashboard_url = kube_conf.get_dashboard_link(pod)
        pod_log_url = kube_conf.get_pod_log_link(pod)
        external_links = {
            link_name: link_url
            for link_name, link_url in (
                ("k8s_dashboard", dashboard_url),
                ("pod_log", pod_log_url),
            )
            if link_url
        }
        if external_links:
            task_run.set_external_resource_urls(external_links)
        task_run.set_task_run_state(TaskRunState.QUEUED)

        try:
            response = self.kube_client.create_namespaced_pod(
                body=pod_request, namespace=pod.namespace
            )
        except ApiException as ex:
            task_run.set_task_run_state(
                TaskRunState.FAILED, error=TaskRunError.build_from_ex(ex, task_run)
            )
            logger.error(
                "Exception when attempting to create Namespaced Pod using: %s",
                pod_request_text,
            )
            raise
        else:
            logger.info(
                "%s has been submitted at pod '%s' at namespace '%s'"
                % (task_run, pod.name, pod.namespace)
            )
            self.log.debug("Pod Creation Response: %s", response)

        if not detach_run:
            self.wait()
        return self