Ejemplo n.º 1
0
    def stop(self):
        """Finalize the tracked run and tear down the tracking context.

        Only the first call on an active tracker does any work. If the root task
        run has not already reached a finished state, it is marked UPSTREAM_FAILED
        when any task run of the run FAILED, and SUCCESS otherwise. A driver task
        run that has not finished is marked SUCCESS. The run state mirrors the root
        task run (SUCCESS only when the root succeeded, FAILED otherwise) and the
        finished-run banner is logged.

        The tracking context managers are always closed, even when finalizing the
        states fails. Any exception raised during shutdown is handed to
        ``_handle_inline_run_error``.
        """
        if not self._active:
            return
        self._active = False
        try:
            try:
                databand_run = self._run
                root_tr = self._task_run
                root_tr.finished_time = utcnow()

                if root_tr.task_run_state not in TaskRunState.finished_states():
                    if any(
                        tr.task_run_state == TaskRunState.FAILED
                        for tr in databand_run.task_runs
                    ):
                        root_tr.set_task_run_state(TaskRunState.UPSTREAM_FAILED)
                    else:
                        root_tr.set_task_run_state(TaskRunState.SUCCESS)

                driver_tr = databand_run.driver_task.current_task_run
                if driver_tr.task_run_state not in TaskRunState.finished_states():
                    driver_tr.set_task_run_state(TaskRunState.SUCCESS)

                if root_tr.task_run_state == TaskRunState.SUCCESS:
                    databand_run.set_run_state(RunState.SUCCESS)
                else:
                    databand_run.set_run_state(RunState.FAILED)
                logger.info(databand_run.describe.run_banner_for_finished())
            finally:
                # Close the context managers even if a state update above raised;
                # otherwise they would be left open after shutdown.
                self._close_all_context_managers()
        except Exception:
            _handle_inline_run_error("dbnd-tracking-shutdown")
Ejemplo n.º 2
0
    def stop(self, finalize_run=True):
        """Optionally finalize the tracked run, then tear down the tracking context.

        Only the first call on an active tracker does any work.

        :param finalize_run: when true, stamp the root task run's finish time and,
            if it has not already reached a finished state, mark it
            UPSTREAM_FAILED (some task run FAILED), FAILED (an exception is
            currently being handled, e.g. a tracked stand-alone script that
            raised) or SUCCESS. The run state then mirrors the root task run.

        The tracking context managers are always closed, even when finalizing
        fails. Any exception raised during shutdown is handed to
        ``_handle_tracking_error``.
        """
        if not self._active:
            return
        self._active = False
        try:
            try:
                # Required for scripts tracking which do not set the state to SUCCESS
                if finalize_run:
                    databand_run = self._run
                    root_tr = self._task_run
                    root_tr.finished_time = utcnow()

                    if root_tr.task_run_state not in TaskRunState.finished_states():
                        if any(
                            tr.task_run_state == TaskRunState.FAILED
                            for tr in databand_run.task_runs
                        ):
                            root_tr.set_task_run_state(TaskRunState.UPSTREAM_FAILED)
                        elif sys.exc_info()[1]:
                            # We get here when a tracked stand-alone python script
                            # is exiting because it raised.
                            root_tr.set_task_run_state(TaskRunState.FAILED)
                        else:
                            root_tr.set_task_run_state(TaskRunState.SUCCESS)

                    if root_tr.task_run_state == TaskRunState.SUCCESS:
                        databand_run.set_run_state(RunState.SUCCESS)
                    else:
                        databand_run.set_run_state(RunState.FAILED)
            finally:
                # Close the context managers even if finalizing the states raised.
                self._close_all_context_managers()
        except Exception:
            _handle_tracking_error("dbnd-tracking-shutdown")
Ejemplo n.º 3
0
    def stop(self):
        """Finalize the tracked run and tear down the tracking context.

        Runs at most once: ``_stoped`` is set before any work is done, so repeated
        or re-entrant calls return immediately. If the root task run has not
        already reached a finished state, it is marked UPSTREAM_FAILED when any
        task run of the run FAILED, and SUCCESS otherwise. The run state is then
        set once, mirroring the root task run, and the finished-run banner is
        logged.

        The tracking context managers are always closed. Any ``Exception`` raised
        during shutdown is handed to ``_handle_inline_error``; BaseExceptions such
        as KeyboardInterrupt are no longer swallowed.
        """
        if self._stoped:
            return
        # Set the flag first so a call made while shutdown is in progress is a no-op.
        self._stoped = True
        try:
            try:
                databand_run = self._run
                root_tr = self._task_run
                root_tr.finished_time = utcnow()

                if root_tr.task_run_state not in TaskRunState.finished_states():
                    if any(
                        tr.task_run_state == TaskRunState.FAILED
                        for tr in databand_run.task_runs
                    ):
                        root_tr.set_task_run_state(TaskRunState.UPSTREAM_FAILED)
                    else:
                        root_tr.set_task_run_state(TaskRunState.SUCCESS)

                # The run state is derived from the root task run exactly once:
                # an UPSTREAM_FAILED root above results in FAILED here.
                if root_tr.task_run_state == TaskRunState.SUCCESS:
                    databand_run.set_run_state(RunState.SUCCESS)
                else:
                    databand_run.set_run_state(RunState.FAILED)
                logger.info(databand_run.describe.run_banner_for_finished())
            finally:
                # Close the context managers even if finalizing the states raised.
                self._close_all_context_managers()
        except Exception:
            _handle_inline_error("dbnd-tracking-shutdown")
Ejemplo n.º 4
0
    def stop(self):
        """Finalize the tracked run and tear down the tracking context.

        Only the first call on an active tracker does any work. If the root task
        run has not already reached a finished state, it is marked UPSTREAM_FAILED
        when any task run of the run FAILED, and SUCCESS otherwise. The run state
        mirrors the root task run (SUCCESS only when the root succeeded, FAILED
        otherwise). The finished-run banner is logged unless
        ``CoreConfig.current().silence_tracking_mode`` is set.

        The tracking context managers are always closed, even when finalizing
        fails. Any exception raised during shutdown is handed to
        ``_handle_tracking_error``.
        """
        if not self._active:
            return
        self._active = False
        try:
            try:
                databand_run = self._run
                root_tr = self._task_run
                root_tr.finished_time = utcnow()

                if root_tr.task_run_state not in TaskRunState.finished_states():
                    if any(
                        tr.task_run_state == TaskRunState.FAILED
                        for tr in databand_run.task_runs
                    ):
                        root_tr.set_task_run_state(TaskRunState.UPSTREAM_FAILED)
                    else:
                        root_tr.set_task_run_state(TaskRunState.SUCCESS)

                if root_tr.task_run_state == TaskRunState.SUCCESS:
                    databand_run.set_run_state(RunState.SUCCESS)
                else:
                    databand_run.set_run_state(RunState.FAILED)

                # TODO: console output is hard to control when printing from
                # somewhere other than the console tracker.
                if not CoreConfig.current().silence_tracking_mode:
                    logger.info(databand_run.describe.run_banner_for_finished())
            finally:
                # Close the context managers even if finalizing the states raised.
                self._close_all_context_managers()
        except Exception:
            _handle_tracking_error("dbnd-tracking-shutdown")
Ejemplo n.º 5
0
    def _terminate_all_running_pods(self):
        """
        Clean up all running pods on terminate.

        Deletes every submitted pod, waits 10 seconds so the pods can be deleted
        and run their own state management, then reconciles each pod's task run
        with Airflow:

        * Airflow task instance finished and task run not in a final state: the
          task run gets the state mapped by ``AIRFLOW_TO_DBND_STATE_MAP``
          (CANCELLED when the Airflow state has no mapping).
        * Airflow task instance finished and task run already final: left as is.
        * Airflow task instance not finished: the task run is set to CANCELLED.

        Errors are logged per pod, so one failing pod does not prevent the others
        from being deleted or reconciled.
        """
        # now we need to clean after the run
        pods_to_delete = sorted(self.submitted_pods.values())
        if not pods_to_delete:
            return

        self.log.info(
            "Terminating run, deleting all %d submitted pods that are still running/not finalized",
            len(pods_to_delete),
        )
        for submitted_pod in pods_to_delete:
            try:
                self.delete_pod(submitted_pod.pod_name)
            except Exception:
                self.log.exception("Failed to terminate pod %s", submitted_pod.pod_name)

        # Wait for pods to be deleted and execute their own state management
        self.log.info(
            "Setting all running/not finalized pods to cancelled in 10 seconds..."
        )
        time.sleep(10)
        for submitted_pod in pods_to_delete:
            # Reconcile each pod independently, so an error on one pod does not
            # leave the remaining task runs in a non-final state.
            try:
                task_run = submitted_pod.task_run
                ti_state = get_airflow_task_instance_state(task_run)
                if not is_task_instance_finished(ti_state):
                    task_run.set_task_run_state(TaskRunState.CANCELLED)
                elif task_run.task_run_state in TaskRunState.final_states():
                    self.log.info(
                        "%s with pod %s is finished: airflow state - %s and databand state - %s. Skipping",
                        task_run,
                        submitted_pod.pod_name,
                        ti_state,
                        task_run.task_run_state,
                    )
                else:
                    self.log.info(
                        "%s with pod %s is not finished: airflow state - %s and databand state - %s. "
                        "Setting the task_run state to match airflow state",
                        task_run,
                        submitted_pod.pod_name,
                        ti_state,
                        task_run.task_run_state,
                    )
                    new_state = AIRFLOW_TO_DBND_STATE_MAP.get(
                        ti_state, TaskRunState.CANCELLED
                    )
                    task_run.set_task_run_state(new_state)
            except Exception:
                self.log.exception(
                    "Could not set pod %s to cancelled!", submitted_pod.pod_name
                )
Ejemplo n.º 6
0
    def cleanup_after_task_run(self, task):
        # type: (Task) -> None
        """Evict cached targets of ``task`` that no unfinished task still reads.

        Every input and output target of ``task`` is a candidate for removal from
        ``TARGET_CACHE``. A candidate is kept when it appears among the inputs of
        any task run whose state is not in ``TaskRunState.final_states()``.
        """
        task_relations = task.ctrl.relations
        removable = set(
            flatten([task_relations.task_inputs, task_relations.task_outputs])
        )

        still_needed = set()
        for other_tr in self.task_runs:
            if other_tr.task_run_state not in TaskRunState.final_states():
                still_needed.update(flatten(other_tr.task.ctrl.relations.task_inputs))

        TARGET_CACHE.clear_for_targets(removable - still_needed)
Ejemplo n.º 7
0
def _collect_errors(task_runs):
    """Build a human-readable summary of the failed task runs.

    Task runs in the UPSTREAM_FAILED state (they did not run because a dependency
    failed) are listed in their own section. Task runs whose state is one of
    ``TaskRunState.direct_fail_states()`` are listed separately.

    :param task_runs: iterable of task runs to inspect.
    :return: the summary text, or an empty string when nothing failed.
    """
    upstream_failed = []
    failed = []
    for task_run in task_runs:
        name = task_run.task.task_name
        state = task_run.task_run_state
        if state == TaskRunState.UPSTREAM_FAILED:
            upstream_failed.append(name)
        elif state in TaskRunState.direct_fail_states():
            failed.append(name)

    sections = []
    if upstream_failed:
        sections.append(
            "Task that didn't run because of failed dependency:\n\t{}\n".format(
                "\n\t".join(upstream_failed)
            )
        )
    if failed:
        sections.append("Failed tasks are:\n\t{}".format("\n\t".join(failed)))
    return "".join(sections)
Ejemplo n.º 8
0
    def get_error_banner(self):
        # type: ()->str
        """Compose the full error report for a failed run.

        The report is the finished-run banner, then one red banner per failed task
        run, then a closing "run has failed" banner, joined by newlines. A task run
        counts as failed when it has a ``last_error`` or its state is in
        ``TaskRunState.direct_fail_states()``. When more than one task run failed,
        the driver/submitter system tasks are left out so they are not printed
        twice.
        """
        run = self.run
        banners = [self.run_banner_for_finished()]

        failed_task_runs = [
            tr
            for tr in run.task_runs
            if tr.last_error or tr.task_run_state in TaskRunState.direct_fail_states()
        ]

        if len(failed_task_runs) > 1:
            # clear out driver task, we don't want to print it twice
            system_names = SystemTaskName.driver_and_submitter
            failed_task_runs = [
                tr for tr in failed_task_runs if tr.task.task_name not in system_names
            ]

        for tr in failed_task_runs:
            header = (
                "Task has been terminated!"
                if tr.task_run_state == TaskRunState.CANCELLED
                else "Task has failed!"
            )
            banners.append(tr.task.ctrl.banner(msg=header, color="red", task_run=tr))

        closing_banner = self.run_banner(
            "Your run has failed! See more info above.",
            color="red",
            show_run_info=False,
        )
        banners.append(closing_banner)
        return u"\n".join(banners)
Ejemplo n.º 9
0
    def _terminate_all_running_pods(self):
        """
        Clean up all running pods on terminate.

        Deletes every submitted pod, waits 10 seconds so the pods can be deleted
        and run their own state management, then sets every pod's task run that is
        not yet in a final state to CANCELLED. Task runs already in a final state
        are logged and skipped.

        Errors are logged per pod, so one failing pod does not prevent the others
        from being deleted or cancelled.
        """
        # now we need to clean after the run
        pods_to_delete = sorted(self.submitted_pods.values())
        if not pods_to_delete:
            return

        self.log.info(
            "Terminating run, deleting all %d submitted pods that are still running/not finalized",
            len(pods_to_delete),
        )
        for submitted_pod in pods_to_delete:
            try:
                self.delete_pod(submitted_pod.pod_name)
            except Exception:
                self.log.exception("Failed to terminate pod %s", submitted_pod.pod_name)

        # Wait for pods to be deleted and execute their own state management
        self.log.info(
            "Setting all running/not finalized pods to cancelled in 10 seconds..."
        )
        time.sleep(10)
        for submitted_pod in pods_to_delete:
            # Handle each pod independently, so an error on one pod does not leave
            # the remaining task runs un-cancelled.
            try:
                task_run = submitted_pod.task_run
                if task_run.task_run_state in TaskRunState.final_states():
                    self.log.info(
                        "%s with pod %s was %s, skipping",
                        task_run,
                        submitted_pod.pod_name,
                        task_run.task_run_state,
                    )
                    continue
                task_run.set_task_run_state(TaskRunState.CANCELLED)
            except Exception:
                self.log.exception(
                    "Could not set pod %s to cancelled!", submitted_pod.pod_name
                )
Ejemplo n.º 10
0
    def do_run(self):
        """Execute the run's task runs in topological (dependency) order.

        For each task, in order:

        * reused task runs are skipped;
        * once a task has failed and ``fail_fast`` is set, every later task gets
          the state from ``run.get_upstream_failed_task_run_state``;
        * a task with an upstream task run in ``TaskRunState.fail_states()`` is
          set to UPSTREAM_FAILED;
        * if the run has been killed, the task is set to FAILED;
        * otherwise the task's runner executes it. A failure is logged and
          remembered; ``DatabandSigTermError`` is propagated immediately.

        States decided here are set with ``track=False`` and reported to the
        tracker in one batch after the loop.

        :raises DatabandRunError: when a task failed to execute and
            ``_collect_errors`` produced a non-empty summary.
        """
        topological_tasks = topological_sort([tr.task for tr in self.task_runs])
        fail_fast = self.settings.run.fail_fast
        task_failed = False

        # NOTE(review): if DatabandSigTermError propagates, the states collected
        # here are never sent to the tracker; confirm this is intended.
        task_runs_to_update_state = []
        for task in topological_tasks:
            tr = self.run.get_task_run_by_id(task.task_id)
            if tr.is_reused:
                continue

            if fail_fast and task_failed:
                state = self.run.get_upstream_failed_task_run_state(tr)

                logger.info("Setting %s to %s", task.task_id, state)
                tr.set_task_run_state(state, track=False)
                task_runs_to_update_state.append(tr)
                continue

            upstream_task_runs = [
                self.run.get_task_run_by_id(t.task_id)
                for t in task.ctrl.task_dag.upstream
            ]
            if any(
                upstream_tr.task_run_state in TaskRunState.fail_states()
                for upstream_tr in upstream_task_runs
            ):
                logger.info(
                    "Setting %s to %s", task.task_id, TaskRunState.UPSTREAM_FAILED
                )
                tr.set_task_run_state(TaskRunState.UPSTREAM_FAILED, track=False)
                task_runs_to_update_state.append(tr)
                continue

            if self.run.is_killed():
                logger.info(
                    "Databand Context is killed! Stopping %s to %s",
                    task.task_id,
                    TaskRunState.FAILED,
                )
                tr.set_task_run_state(TaskRunState.FAILED, track=False)
                task_runs_to_update_state.append(tr)
                continue

            logger.debug("Executing task: %s", task.task_id)

            try:
                tr.runner.execute()
            except DatabandSigTermError:
                # Termination aborts the whole run; it is not a task failure.
                raise
            except Exception:
                task_failed = True
                logger.exception("Failed to execute task '%s':", task.task_id)

        if task_runs_to_update_state:
            self.run.tracker.set_task_run_states(task_runs_to_update_state)

        if task_failed:
            err = _collect_errors(self.run.task_runs)

            if err:
                raise DatabandRunError(err)
Ejemplo n.º 11
0
 def __init__(self):
     """Initialize the base class with every ``TaskRunState`` value as the allowed
     choices, matched without case sensitivity."""
     state_values = TaskRunState.all_values()
     super(TaskRunStateChoice, self).__init__(
         choices=state_values, case_sensitive=False
     )
Ejemplo n.º 12
0
    def run_banner(self, msg, color="white", show_run_info=False, show_tasks_info=True):
        """Render a text banner summarizing the current run.

        :param msg: headline passed to ``TextBanner``.
        :param color: banner color passed to ``TextBanner``.
        :param show_run_info: also add user, log locations, code version, command
            line, run uid, DB, environment and executor settings.
        :param show_tasks_info: add the tasks section when the driver task is a
            driver.
        :return: the rendered banner text.
        """
        banner = TextBanner(msg, color)
        run = self.run  # type: DatabandRun
        task_run_env = run.context.task_run_env  # type: TaskRunEnvInfo
        driver_task = run.driver_task_run.task

        if show_tasks_info and driver_task.is_driver:
            self._add_tasks_info(banner)

        banner.column("TRACKER URL", run.run_url, skip_if_empty=True)

        root_info = run.root_run_info
        if root_info.root_run_uid != run.run_uid:
            # The run has a different root run: point at the root as well.
            banner.column("ROOT TRACKER URL", root_info.root_run_url, skip_if_empty=True)
            banner.column("ROOT UID URL", root_info.root_run_uid, skip_if_empty=True)

        scheduled = run.scheduled_run_info
        if scheduled:
            banner.column_properties(
                "SCHEDULED",
                [
                    ("scheduled_job", scheduled.scheduled_job_uid),
                    ("scheduled_date", scheduled.scheduled_date),
                    ("dag_run_id", scheduled.scheduled_job_dag_run_id),
                ],
            )

        if show_run_info:
            banner.new_line()
            banner.column("USER", task_run_env.user)
            log_locations = banner.f_simple_dict(
                [
                    ("local", driver_task.local_driver_log),
                    ("remote", driver_task.remote_driver_root),
                ]
            )
            banner.column("LOG", log_locations)
            banner.column("USER CODE VERSION", task_run_env.user_code_version)
            banner.column("CMD", task_run_env.cmd_line)
            banner.column("RUN UID", "%s" % run.run_uid)
            banner.column("DB", self.context.settings.core.sql_conn_repr)
            banner.column("ENV", run.env.name)
            executor_settings = banner.f_simple_dict(
                [
                    ("TASK_EXECUTOR", run.task_executor_type),
                    ("PARALLEL", run.parallel),
                    ("SUBMIT_DRIVER", run.submit_driver),
                    ("SUBMIT_TASKS", run.submit_tasks),
                ],
                skip_if_empty=True,
            )
            banner.column("RUN", executor_settings, skip_if_empty=True)
            if task_run_env.user_data:
                banner.column("USER DATA", task_run_env.user_data, skip_if_empty=True)
            banner.new_line()

        failed_task_ids = [
            tr.task.task_id
            for tr in run.task_runs
            if tr.task_run_state in TaskRunState.direct_fail_states()
        ]
        if failed_task_ids:
            banner.column("FAILED", "\n\t".join(failed_task_ids))

        banner.new_line()
        return banner.getvalue()
Ejemplo n.º 13
0
    def run_banner(self, msg, color="white", show_run_info=False, show_tasks_info=True):
        """Render a text banner summarizing the current run.

        The tasks section and the DB / Airflow DB / EXECUTE columns are only shown
        when ``run.source == UpdateSource.dbnd``.

        :param msg: headline passed to ``TextBanner``.
        :param color: banner color passed to ``TextBanner``.
        :param show_run_info: also add run parameters, log locations, code version,
            command line and (for dbnd-sourced runs) DB and executor settings.
        :param show_tasks_info: add the tasks section for dbnd-sourced runs whose
            driver task is a driver.
        :return: the rendered banner text.
        """
        banner = TextBanner(msg, color)
        run = self.run  # type: DatabandRun
        task_run_env = run.context.task_run_env  # type: TaskRunEnvInfo
        driver_task = run.driver_task_run.task
        orchestration_mode = run.source == UpdateSource.dbnd

        banner.column("TRACKER URL", run.run_url, skip_if_empty=True)
        if show_tasks_info and orchestration_mode and driver_task.is_driver:
            self._add_tasks_info(banner)

        root_info = run.root_run_info
        if root_info.root_run_uid != run.run_uid:
            # The run has a different root run: point at the root as well.
            banner.column("ROOT TRACKER URL", root_info.root_run_url, skip_if_empty=True)
            banner.column("ROOT UID URL", root_info.root_run_uid, skip_if_empty=True)

        scheduled = run.scheduled_run_info
        if scheduled:
            banner.column_properties(
                "SCHEDULED",
                [
                    ("scheduled_job", scheduled.scheduled_job_uid),
                    ("scheduled_date", scheduled.scheduled_date),
                    ("dag_run_id", scheduled.scheduled_job_dag_run_id),
                ],
            )

        if show_run_info:
            banner.new_line()
            run_summary = banner.f_simple_dict(
                [
                    ("user", task_run_env.user),
                    ("run_uid", "%s" % run.run_uid),
                    ("env", run.env.name),
                ]
            )
            banner.column("RUN", run_summary)
            log_locations = banner.f_simple_dict(
                [
                    ("local", driver_task.local_driver_log),
                    ("remote", driver_task.remote_driver_root),
                ]
            )
            banner.column("LOG", log_locations)
            banner.column("USER CODE VERSION", task_run_env.user_code_version)
            banner.column("CMD", task_run_env.cmd_line)

            if orchestration_mode:
                if run.context.settings.core.is_db_store_enabled():
                    banner.column("DB", self.context.settings.core.sql_conn_repr)
                if run.task_executor_type.startswith("airflow"):
                    assert_airflow_enabled()
                    # Imported lazily: only needed for Airflow executors.
                    from dbnd_airflow.db_utils import airflow_sql_conn_repr

                    banner.column("Airflow DB", airflow_sql_conn_repr())
                executor_settings = banner.f_simple_dict(
                    [
                        ("TASK_EXECUTOR", run.task_executor_type),
                        ("PARALLEL", run.parallel),
                        ("SUBMIT_DRIVER", run.submit_driver),
                        ("SUBMIT_TASKS", run.submit_tasks),
                    ],
                    skip_if_empty=True,
                )
                banner.column("EXECUTE", executor_settings, skip_if_empty=True)
            user_data = task_run_env.user_data
            if user_data and user_data != "None":
                banner.column("USER DATA", user_data, skip_if_empty=True)
            banner.new_line()

        failed_task_ids = [
            tr.task.task_id
            for tr in run.task_runs
            if tr.task_run_state in TaskRunState.direct_fail_states()
        ]
        if failed_task_ids:
            banner.column("FAILED", "\n\t".join(failed_task_ids))

        banner.new_line()
        return banner.getvalue()
Ejemplo n.º 14
0
    def run_banner(self, msg, color="white", show_run_info=False, show_tasks_info=True):
        """Render a text banner summarizing the current run.

        Sections that depend on the executor (tasks, driver log, EXECUTE settings,
        TASK_BAND) are only added when ``run.is_orchestration`` is true.

        :param msg: headline passed to ``TextBanner``.
        :param color: banner color passed to ``TextBanner``.
        :param show_run_info: also add the driver log, run parameters, code
            version, command line and executor settings.
        :param show_tasks_info: add the tasks section when the run executor type
            is ``SystemTaskName.driver``.
        :return: the rendered banner text.
        """
        banner = TextBanner(msg, color)
        run = self.run  # type: DatabandRun
        task_run_env = run.context.task_run_env  # type: TaskRunEnvInfo

        banner.column("TRACKER URL", run.run_url, skip_if_empty=True)
        banner.column("TRACKERS", CoreConfig().tracker)
        if run.is_orchestration:
            executor = run.run_executor
            driver_tr = run.driver_task_run

            is_driver_executor = executor.run_executor_type == SystemTaskName.driver
            if show_tasks_info and is_driver_executor:
                self._add_tasks_info(banner)
            if show_run_info and driver_tr and driver_tr.log:
                log_files = banner.f_simple_dict(
                    [
                        ("local", driver_tr.log.local_log_file),
                        ("remote", driver_tr.log.remote_log_file),
                    ],
                    skip_if_empty=True,
                )
                banner.column("LOG", log_files)

        root_info = run.root_run_info
        if root_info.root_run_uid != run.run_uid:
            # The run has a different root run: point at the root as well.
            banner.column("ROOT TRACKER URL", root_info.root_run_url, skip_if_empty=True)
            banner.column("ROOT UID URL", root_info.root_run_uid, skip_if_empty=True)

        scheduled = run.scheduled_run_info
        if scheduled:
            banner.column_properties(
                "SCHEDULED",
                [
                    ("scheduled_job", scheduled.scheduled_job_uid),
                    ("scheduled_date", scheduled.scheduled_date),
                    ("dag_run_id", scheduled.scheduled_job_dag_run_id),
                ],
            )

        if show_run_info:
            banner.new_line()
            run_summary = banner.f_simple_dict(
                [
                    ("user", task_run_env.user),
                    ("run_uid", "%s" % run.run_uid),
                    ("env", run.env.name),
                ]
            )
            banner.column("RUN", run_summary)
            banner.column("USER CODE VERSION", task_run_env.user_code_version)
            banner.column("CMD", task_run_env.cmd_line)

            if run.is_orchestration:
                executor = run.run_executor
                executor_settings = banner.f_simple_dict(
                    [
                        ("TASK_EXECUTOR", executor.task_executor_type),
                        ("PARALLEL", executor.parallel),
                        ("SUBMIT_DRIVER", executor.submit_driver),
                        ("SUBMIT_TASKS", executor.submit_tasks),
                    ],
                    skip_if_empty=True,
                )
                banner.column("EXECUTE", executor_settings, skip_if_empty=True)
            user_data = task_run_env.user_data
            if user_data and user_data != "None":
                banner.column("USER DATA", user_data, skip_if_empty=True)
            banner.new_line()

        failed_task_ids = [
            tr.task.task_id
            for tr in run.task_runs
            if tr.task_run_state in TaskRunState.direct_fail_states()
        ]
        if failed_task_ids:
            banner.column("FAILED", "\n\t".join(failed_task_ids))

        root_tr = run.root_task_run
        if root_tr and run.is_orchestration:
            banner.column("TASK_BAND", root_tr.task.task_band)

        banner.new_line()
        return banner.getvalue()