Example
class DatabandRun(SingletonContext):
    def __init__(
        self,
        context,
        task_or_task_name,
        run_uid=None,
        scheduled_run_info=None,
        send_heartbeat=True,
        existing_run=None,
        job_name=None,
        source=UpdateSource.dbnd,
        af_context=None,
    ):
        # type: (DatabandContext, Union[Task, str], Optional[UUID], Optional[ScheduledRunInfo], bool, Optional[bool], Optional[str], UpdateSource, Optional[Any]) -> None
        self.context = context
        s = self.context.settings  # type: DatabandSettings

        if isinstance(task_or_task_name, six.string_types):
            self.root_task_name = task_or_task_name
            self.root_task = None
        elif isinstance(task_or_task_name, Task):
            self.root_task_name = task_or_task_name.task_name
            self.root_task = task_or_task_name
        else:
            raise DatabandRuntimeError(
                "task_or_task_name must be a Task or a task name string")

        self.job_name = job_name or self.root_task_name

        self.description = s.run.description
        self.is_archived = s.run.is_archived
        self.source = source
        # this was added to allow the scheduler to create the run that is then continued by the actual run command, instead of having 2 separate runs
        if not run_uid and DBND_RUN_UID in os.environ:
            # we pop so that if this run spawns subprocesses with their own runs, they will be associated
            # via the sub-runs mechanism instead of being fused into this run directly
            run_uid = os.environ.pop(DBND_RUN_UID)
        if run_uid:
            self.run_uid = run_uid
            self.existing_run = True
        else:
            self.run_uid = get_uuid()
            self.existing_run = False

        if existing_run is not None:
            self.existing_run = existing_run

        self.name = s.run.name or get_name_for_uid(self.run_uid)
        # this is so the scheduler can create a run with partial information and then have the subprocess running the actual cmd fill in the details
        self.resubmit_run = (DBND_RESUBMIT_RUN in os.environ
                             and os.environ.pop(DBND_RESUBMIT_RUN) == "true")

        # AIRFLOW: move into executor
        # dag_id, execution_date and run_id are used by Airflow
        self.dag_id = AD_HOC_DAG_PREFIX + self.root_task_name
        self.execution_date = unique_execution_date()
        run_id = s.run.id
        if not run_id:
            # we need this name, otherwise Airflow will try to manage our local jobs in its scheduler
            # (zombie cleanup and so on)
            run_id = "backfill_{0}_{1}".format(self.name,
                                               self.execution_date.isoformat())
        self.run_id = run_id
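        # e.g. run_id == "backfill_<run name>_2021-03-02T10:20:30.123456" (illustrative values)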

        self._template_vars = self._build_template_vars()

        self.is_tracked = True

        self.runtime_errors = []
        self._run_state = None
        self.task_runs = []  # type: List[TaskRun]
        self.task_runs_by_id = {}
        self.task_runs_by_af_id = {}

        self.target_origin = TargetIdentitySourceMap()
        self.describe = DescribeRun(self)
        self.tracker = RunTracker(self,
                                  tracking_store=self.context.tracking_store)

        # all run-context-specific state
        self.root_run_info = RootRunInfo.from_env(current_run=self)
        self.scheduled_run_info = scheduled_run_info or ScheduledRunInfo.from_env(
            self.run_uid)

        # now we can add the driver task
        self.driver_task_run = None  # type: Optional[TaskRun]
        self.root_task_run = None  # type: Optional[TaskRun]

        self.run_folder_prefix = os.path.join(
            "log",
            self.execution_date.strftime("%Y-%m-%d"),
            "%s_%s_%s" % (
                self.execution_date.strftime("%Y-%m-%dT%H%M%S.%f"),
                self.root_task_name,
                self.name,
            ),
        )
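        # e.g. run_folder_prefix == "log/2021-03-02/2021-03-02T102030.123456_<root task>_<run name>" (illustrative)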

        self.run_config = self.context.settings.run  # type: RunConfig
        self.env = env = self.context.env

        self.local_engine = self._get_engine_config(env.local_engine)
        self.remote_engine = self._get_engine_config(env.remote_engine
                                                     or env.local_engine)

        self.submit_driver = (self.run_config.submit_driver
                              if self.run_config.submit_driver is not None else
                              env.submit_driver)
        self.submit_tasks = (self.run_config.submit_tasks
                             if self.run_config.submit_tasks is not None else
                             env.submit_tasks)
        self.task_executor_type, self.parallel = calculate_task_executor_type(
            self.submit_tasks, self.remote_engine, self.context.settings)

        self.sends_heartbeat = send_heartbeat
        self.dynamic_af_tasks_count = dict()
        self.af_context = af_context
        self.start_time = None
        self.finished_time = None

    def _get_engine_config(self, name):
        # type: (Union[str, EngineConfig]) -> EngineConfig
        return build_task_from_config(name, EngineConfig)

    @property
    def run_url(self):
        return self.tracker.run_url

    @property
    def task(self):
        return self.root_task

    @property
    def driver_task(self):
        # type: () -> _DbndDriverTask
        return self.driver_task_run.task

    @property
    def driver_dump(self):
        return self.driver_task_run.task.driver_dump

    def _build_template_vars(self):
        # template vars
        ds = self.execution_date.strftime("%Y-%m-%d")
        ts = self.execution_date.isoformat()
        ds_nodash = ds.replace("-", "")
        ts_nodash = ts.replace("-", "").replace(":", "")
        ts_safe = ts.replace(":", "")

        return {
            "run": self,
            "run_ds": ds,
            "run_ts": ts,
            "run_ds_nodash": ds_nodash,
            "run_ts_nodash": ts_nodash,
            "run_ts_safe": ts_safe,
        }
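
    # For an execution_date of 2021-03-02 10:20:30.123456, the vars above resolve to
    # (illustrative values):
    #   run_ds        -> "2021-03-02"
    #   run_ts        -> "2021-03-02T10:20:30.123456"
    #   run_ds_nodash -> "20210302"
    #   run_ts_nodash -> "20210302T102030.123456"
    #   run_ts_safe   -> "2021-03-02T102030.123456"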

    # TODO: split into get_by_id/get_by_af_id
    def get_task_run(self, task_id):
        # type: (str) -> TaskRun
        return self.get_task_run_by_id(task_id) or self.get_task_run_by_af_id(
            task_id)

    def get_task_run_by_id(self, task_id):
        # type: (str) -> TaskRun
        return self.task_runs_by_id.get(task_id)

    def get_task_run_by_af_id(self, task_id):
        # type: (str) -> TaskRun
        return self.task_runs_by_af_id.get(task_id)

    def get_af_task_ids(self, task_ids):
        return [self.get_task_run(task_id).task_af_id for task_id in task_ids]

    def get_task(self, task_id):
        # type: (str) -> Task
        return self.get_task_run(task_id).task

    @property
    def describe_dag(self):
        return self.root_task.ctrl.describe_dag

    def set_run_state(self, state):
        self._run_state = state
        self.tracker.set_run_state(state)

    def run_dynamic_task(self, task, task_engine=None):
        if task_engine is None:
            task_engine = self.current_engine_config
        task_run = self.create_dynamic_task_run(task, task_engine)
        task_run.runner.execute()
        return task_run

    def _build_driver_task(self):
        if self.submit_driver and not self.existing_run:
            logger.info("Submitting job to remote execution")
            task_name = SystemTaskName.driver_submit
            is_submitter = True
            is_driver = False
            host_engine = self.local_engine.clone(require_submit=False)
            target_engine = self.local_engine.clone(require_submit=False)
            task_executor_type = TaskExecutorType.local
        else:
            task_name = SystemTaskName.driver
            # even if it's an existing run, we may be running from Airflow,
            # so the run is actually "submitted" (the root Airflow job has no info);
            # we want to capture the "real" info of the run
            is_submitter = not self.existing_run or self.resubmit_run
            is_driver = True
            task_executor_type = self.task_executor_type

            if self.submit_driver:
                # submit_driver is true, but we are in an existing run:
                # we are after the jump from submit to driver execution (on the remote engine)
                host_engine = self.remote_engine.clone(require_submit=False)
            else:
                host_engine = self.local_engine.clone(
                    require_submit=False
                )  # we are already running on this engine

            target_engine = self.remote_engine
            if not self.submit_tasks or task_executor_type == "airflow_kubernetes":
                target_engine = target_engine.clone(require_submit=False)

        dbnd_local_root = host_engine.dbnd_local_root or self.env.dbnd_local_root
        run_folder_prefix = self.run_folder_prefix

        local_driver_root = dbnd_local_root.folder(run_folder_prefix)
        local_driver_log = local_driver_root.partition("%s.log" % task_name)

        remote_driver_root = self.env.dbnd_root.folder(run_folder_prefix)
        driver_dump = remote_driver_root.file("%s.pickle" % task_name)

        driver_task = _DbndDriverTask(
            task_name=task_name,
            task_version=self.run_uid,
            execution_date=self.execution_date,
            is_submitter=is_submitter,
            is_driver=is_driver,
            host_engine=host_engine,
            target_engine=target_engine,
            task_executor_type=task_executor_type,
            local_driver_root=local_driver_root,
            local_driver_log=local_driver_log,
            remote_driver_root=remote_driver_root,
            driver_dump=driver_dump,
            sends_heartbeat=self.sends_heartbeat,
        )

        tr = TaskRun(task=driver_task,
                     run=self,
                     task_engine=driver_task.host_engine)
        self._add_task_run(tr)
        return tr

    def _on_enter(self):
        if self.driver_task_run is None:
            # we are in submit/driver
            self.driver_task_run = self._build_driver_task()
            self.current_engine_config = self.driver_task_run.task.host_engine
            self.tracker.init_run()
        else:
            # we are in a task run (after the jump)
            self.current_engine_config = self.driver_task_run.task.target_engine.clone(
                require_submit=False)

    def _dbnd_run_error(self, ex):
        if (
                # what scenario is this airflow filtering supposed to help with?
                # I had Airflow put a default airflow.cfg in .dbnd, causing a validation error in
                # k8sExecutor that was invisible in the console (only in the task log)
            ("airflow" not in ex.__class__.__name__.lower()
             or ex.__class__.__name__ == "AirflowConfigException")
                and "Failed tasks are:" not in str(ex)
                and not isinstance(ex, DatabandRunError)
                and not isinstance(ex, KeyboardInterrupt)
                and not isinstance(ex, DatabandSigTermError)):
            logger.exception(ex)

        if (isinstance(ex, KeyboardInterrupt)
                or isinstance(ex, DatabandSigTermError) or self.is_killed()):
            run_state = RunState.CANCELLED
            unfinished_task_state = TaskRunState.UPSTREAM_FAILED
        elif isinstance(ex, DatabandFailFastError):
            run_state = RunState.FAILED
            unfinished_task_state = TaskRunState.UPSTREAM_FAILED
        else:
            run_state = RunState.FAILED
            unfinished_task_state = TaskRunState.FAILED

        self.set_run_state(run_state)
        self.tracker.tracking_store.set_unfinished_tasks_state(
            run_uid=self.run_uid, state=unfinished_task_state)

        err_banner_msg = self.describe.get_error_banner()
        logger.error(u"\n\n{sep}\n{banner}\n{sep}".format(
            sep=console_utils.ERROR_SEPARATOR, banner=err_banner_msg))
        return DatabandRunError("Run has failed: %s" % ex,
                                run=self,
                                nested_exceptions=ex)

    def run_driver(self):
        """
        Runs the main driver!
        """
        # with captures_log_into_file_as_task_file(log_file=self.local_driver_log.path):
        try:
            self.start_time = utcnow()
            self.driver_task_run.runner.execute()
            self.finished_time = utcnow()
        except DatabandRunError as ex:
            self._dbnd_run_error(ex)
            raise
        except (Exception, KeyboardInterrupt, SystemExit) as ex:
            raise self._dbnd_run_error(ex)
        finally:
            try:
                self.driver_task.host_engine.cleanup_after_run()
            except Exception:
                logger.exception(
                    "Failed to shutdown the current run, continuing")

        return self

    @property
    def duration(self):
        if self.finished_time and self.start_time:
            return self.finished_time - self.start_time
        return None

    def _get_task_by_id(self, task_id):
        task = self.context.task_instance_cache.get_task_by_id(task_id)
        if task is None:
            raise DatabandRuntimeError(
                "Failed to find task %s in current context" % task_id)

        return task

    def save_run(self, target_file=None):
        """
        Dumps the current run and context to a file.
        """
        t = target_file or self.driver_dump
        logger.info("Saving current run into %s", t)
        with t.open("wb") as fp:
            cloudpickle.dump(obj=self, file=fp)

    @contextlib.contextmanager
    def run_context(self):
        # type: (DatabandRun) -> Iterator[DatabandRun]

        from dbnd._core.context.databand_context import DatabandContext  # noqa: F811

        with DatabandContext.context(_context=self.context):
            with DatabandRun.context(_context=self) as dr:
                yield dr

    @classmethod
    def load_run(cls, dump_file, disable_tracking_api):
        # type: (FileTarget, bool) -> DatabandRun
        logger.info("Loading dbnd run from %s", dump_file)
        with dump_file.open("rb") as fp:
            databand_run = cloudpickle.load(file=fp)
            if disable_tracking_api:
                databand_run.context.tracking_store.disable_tracking_api()
                logger.info("Tracking has been disabled")
        try:
            if databand_run.context.settings.core.pickle_handler:
                pickle_handler = load_python_callable(
                    databand_run.context.settings.core.pickle_handler)
                pickle_handler(databand_run)
        except Exception as e:
            logger.warning(
                "Error while trying to handle pickle with custom handler: %s", e)
        return databand_run

    def get_template_vars(self):
        return self._template_vars

    def create_dynamic_task_run(self, task, task_engine, task_af_id=None):
        if task_af_id is None:
            task_name = task.friendly_task_name
            if task_name in self.dynamic_af_tasks_count:
                self.dynamic_af_tasks_count[task_name] += 1
                task_af_id = "{}_{}".format(
                    task_name, self.dynamic_af_tasks_count[task_name])
            else:
                self.dynamic_af_tasks_count[task_name] = 1
                task_af_id = task_name

        tr = TaskRun(
            task=task,
            run=self,
            is_dynamic=True,
            task_engine=task_engine,
            task_af_id=task_af_id,
        )
        self.add_task_runs([tr])
        return tr
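
    # Repeated dynamic tasks get a numeric suffix: the first dynamic task named
    # "my_task" gets task_af_id "my_task", the second "my_task_2", and so on.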

    def add_task_runs(self, task_runs):
        # type: (List[TaskRun]) -> None
        for tr in task_runs:
            self._add_task_run(tr)

        self.tracker.add_task_runs(task_runs)

    def _add_task_run(self, task_run):
        self.task_runs.append(task_run)
        self.task_runs_by_id[task_run.task.task_id] = task_run
        self.task_runs_by_af_id[task_run.task_af_id] = task_run

        task_run.task.ctrl.last_task_run = task_run

    def cleanup_after_task_run(self, task):
        # type: (Task) -> None
        rels = task.ctrl.relations
        # potentially, all input/output targets of the current task could be removed
        targets_to_clean = set(flatten([rels.task_inputs, rels.task_outputs]))

        targets_in_use = set()
        # any target that appears in the inputs of a not-yet-finished task shouldn't be removed
        for tr in self.task_runs:
            if tr.task_run_state in TaskRunState.final_states():
                continue
            # collect still-needed inputs; they are excluded from targets_to_clean below
            for target in flatten(tr.task.ctrl.relations.task_inputs):
                targets_in_use.add(target)

        TARGET_CACHE.clear_for_targets(targets_to_clean - targets_in_use)

    def get_context_spawn_env(self):
        env = {}
        if has_current_task():
            current = current_task()
        else:
            current = self.root_task

        if current:
            tr = self.get_task_run_by_id(current.task_id)
            if tr:
                env[DBND_PARENT_TASK_RUN_UID] = str(tr.task_run_uid)
                env[DBND_PARENT_TASK_RUN_ATTEMPT_UID] = str(
                    tr.task_run_attempt_uid)

        env[DBND_ROOT_RUN_UID] = str(self.root_run_info.root_run_uid)
        env[DBND_ROOT_RUN_TRACKER_URL] = self.root_run_info.root_run_url

        if self.context.settings.core.user_code_on_fork:
            env[ENV_DBND__USER_PRE_INIT] = self.context.settings.core.user_code_on_fork
        return env

    def _init_without_run(self):
        self.driver_task_run.task.build_root_task_runs(self)

    def is_killed(self):
        return _is_killed.is_set()

    def _internal_kill(self):
        """
        Called by the TaskRun handler, so we know the run was "canceled";
        otherwise we would get a regular exception.
        """
        _is_killed.set()

    def kill(self):
        """
        Called from user space; kills the current task only.
        """
        # this is a very naive stop implementation:
        # in the case of a simple executor, we'll run the task.on_kill code
        _is_killed.set()
        try:
            current_task = None
            from dbnd._core.task_build.task_context import TaskContext, TaskContextPhase

            tc = TaskContext.try_instance()
            if tc.phase == TaskContextPhase.RUN:
                current_list = list(tc.stack)
                if current_list:
                    current_task = current_list.pop()
        except Exception as ex:
            logger.error("Failed to find current task: %s" % ex)
            return

        if not current_task:
            logger.info("No current task.. Killing nothing..")
            return

        try:
            current_task.on_kill()
        except Exception as ex:
            logger.error("Failed to kill current task %s: %s" %
                         (current_task, ex))
            return

    def kill_run(self):
        _is_killed.set()
        try:
            return kill_run(str(self.run_uid), ctx=self.context)
        except Exception as e:
            raise DatabandFailFastError(
                "Could not send request to kill databand run!", e)

    def get_current_dbnd_local_root(self):
        # we should return the proper engine config here, based on the context we are
        # running in right now: it could be the submit, driver, or task engine
        return self.env.dbnd_local_root
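
For orientation, here is a minimal sketch of how a run object like this might be driven end to end. The new_dbnd_context helper and the exact call sequence are assumptions inferred from the code above, not a documented entry point; a real run is normally created by the dbnd CLI/driver machinery rather than by hand:

from dbnd import new_dbnd_context  # assumed entry point, not shown in the snippet above

with new_dbnd_context() as context:
    # build a run for a task name; the driver resolves the root task later
    run = DatabandRun(context=context, task_or_task_name="my_pipeline")
    with run.run_context() as dr:  # pushes DatabandContext/DatabandRun; _on_enter builds the driver task
        dr.run_driver()  # executes the driver task; raises DatabandRunError on failure
    print(run.run_url)  # tracking URL exposed by the RunTracker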