예제 #1
0
    def initialize_band(self):
        """Run ``self.task.band()`` and record its result.

        ``band()`` is called inside the context returned by
        ``self.task._auto_load_save_params(auto_read=False,
        normalize_on_change=True)`` and, when ``is_airflow_enabled()`` is
        true, also inside the Databand operator-catcher DAG context.
        The result is assigned to ``self.task._task_band_result`` while those
        contexts are still active and to ``self.task_band_result`` afterwards
        (the raw value).

        :raises: on any failure a red "Failed to run ..." banner is logged at
            ERROR level; the original exception is re-raised unchanged when
            ``self.task._conf__decorator_spec`` is truthy, otherwise the error
            built by ``friendly_error.task_build.failed_to_call_band`` is raised.
        """
        try:
            contexts = [
                self.task._auto_load_save_params(
                    auto_read=False,
                    normalize_on_change=True,
                ),
            ]
            if is_airflow_enabled():
                from dbnd_airflow.dbnd_task_executor.airflow_operators_catcher import (
                    get_databand_op_catcher_dag,
                )

                contexts.append(get_databand_op_catcher_dag())

            with nested(*contexts):
                band_result = self.task.band()
                # assigned while the contexts are active; presumably the
                # auto load/save params context normalizes this copy
                self.task._task_band_result = band_result

            # the value exactly as returned by band()
            self.task_band_result = band_result
        except Exception as err:
            failure_banner = self.visualiser.banner(
                msg="Failed to run %s" % _band_call_str(self.task),
                color="red",
                exc_info=sys.exc_info(),
            )
            logger.error(failure_banner)

            if not self.task._conf__decorator_spec:
                raise friendly_error.task_build.failed_to_call_band(err, self.task)
            # decorated tasks: propagate the original exception untouched
            raise
예제 #2
0
파일: task_ctrl.py 프로젝트: kalebinn/dbnd
 def task_context(self, phase):
     """Yield once inside the task's ``phase`` context and its log-filter context.

     Only these two contexts are entered here: logs and user wrappers are
     intentionally not wanted at this stage.
     """
     task = self.task
     phase_ctx = task_context(task, phase)
     log_filter_ctx = TaskContextFilter.task_context(task.task_id)
     with nested(phase_ctx, log_filter_ctx):
         yield
예제 #3
0
def af_tracking_context(task_run, airflow_context, operator):
    """
    Wrap the execution with handling the environment management.

    Generator meant to be driven as a context manager around an operator's
    execution. It yields exactly once:

    * with no extra context when ``task_run`` is falsy;
    * with no extra context when building the tracking information or the
      operator wrapper raises (the error is logged and swallowed, so a
      tracking problem never prevents the operator from running);
    * otherwise inside ``env(**tracking_info)`` and the operator wrapper.

    :param task_run: current task run; a falsy value disables tracking.
    :param airflow_context: passed to ``get_tracking_information``.
    :param operator: operator wrapped by ``wrap_operator_with_tracking_info``.
    """
    if not task_run:
        # aborting - can't enter the context without task_run
        yield
        return

    try:
        tracking_info = get_tracking_information(airflow_context, task_run)
        operator_wrapper = wrap_operator_with_tracking_info(
            tracking_info, operator)
    except Exception as e:
        # best effort: log and fall back to running without tracking
        logger.error(
            "exception caught adding tracking context to operator execution %s, "
            "continuing without tracking context",
            e,
            exc_info=True,
        )
        tracking_failed = True
    else:
        tracking_failed = False

    # Yield outside the ``except`` clause so the wrapped execution does not run
    # while the tracking failure is still the "currently handled" exception
    # (errors thrown back into the generator could otherwise be reported as
    # occurring during handling of it).
    if tracking_failed:
        yield
        return

    # wrap the execution with tracking info in the environment
    with nested(env(**tracking_info), operator_wrapper):
        yield
예제 #4
0
파일: databand_run.py 프로젝트: Dtchil/dbnd
    def run(self):
        """Execute the root task runs of the current Databand run.

        Order of operations:

        * mark the run RUNNING when acting as submitter;
        * validate the git policy and let the remote engine prepare the run;
        * build the root task runs;
        * driver only: when ``system.describe`` is set, describe the DAG,
          log a "Described!" banner and stop; otherwise log the "created"
          banner and task tree, save the run if ``is_save_run`` allows it and
          start a heartbeat sender if ``sends_heartbeat`` is set;
        * run every non-skipped task run through the task executor inside
          ``nested(<heartbeat or None>)``;
        * log the finished (driver) or submitted (non-driver) banner.

        :return: the run for a driver that executed the tasks, else ``None``.
        """
        driver_task_run = current_task_run()
        run = driver_task_run.run  # type: DatabandRun
        if self.is_submitter:
            run.set_run_state(RunState.RUNNING)

        run.context.settings.git.validate_git_policy()

        # let prepare for remote execution
        run.remote_engine.prepare_for_run(run)

        root_task_runs = self.build_root_task_runs(run)

        heartbeat_ctx = None
        # describe is only handled by the local (driver) controller for now
        if self.is_driver:
            if run.context.settings.system.describe:
                run.describe_dag.describe_dag()
                described_banner = run.describe.run_banner("Described!",
                                                           color="blue")
                logger.info(described_banner)
                return None

            main_run = run.root_task_run
            run.root_task.ctrl.banner(
                "Main task '%s' has been created!" % main_run.task_af_id,
                color="cyan",
                task_run=main_run,
            )
            print_tasks_tree(main_run.task, root_task_runs)

            if self.is_save_run(run, root_task_runs):
                run.save_run()

            if self.sends_heartbeat:
                heartbeat_ctx = start_heartbeat_sender(driver_task_run)

        runnable = [
            task_run for task_run in root_task_runs if not task_run.is_skipped
        ]

        # the executor is created without the driver task
        executor = get_task_executor(
            run,
            task_executor_type=self.task_executor_type,
            host_engine=self.host_engine,
            target_engine=run.root_task_run.task_engine,
            task_runs=runnable,
        )

        with nested(heartbeat_ctx):
            executor.do_run()

        if not self.is_driver:
            logger.info(run.describe.run_banner_for_submitted())
            return None

        # This is great success!
        run.set_run_state(RunState.SUCCESS)
        logger.info(run.describe.run_banner_for_finished())
        return run
예제 #5
0
def send_heartbeat(run_uid, databand_url, heartbeat_interval, driver_pid,
                   tracker, tracker_api):
    """Send heartbeats for ``run_uid`` through ``send_heartbeat_continuously``.

    The call runs under a temporary ``core`` config override containing the
    tracker list (``tracker`` split on commas), ``tracker_api`` and
    ``databand_url``. When ``tracker_api == "db"`` a dbnd context named
    ``"send_heartbeat"`` (created with ``autoload_modules=False``) is entered
    as well. The tracking store is then taken from ``CoreConfig`` and passed
    along with the interval and the driver PID.
    """
    from dbnd import config
    from dbnd._core.settings import CoreConfig
    from dbnd._core.task_executor.heartbeat_sender import send_heartbeat_continuously

    core_overrides = {
        "tracker": tracker.split(","),
        "tracker_api": tracker_api,
        "databand_url": databand_url,
    }
    with config({"core": core_overrides}):
        extra_contexts = []
        if tracker_api == "db":
            from dbnd import new_dbnd_context

            extra_contexts.append(
                new_dbnd_context(name="send_heartbeat", autoload_modules=False)
            )

        with nested_context.nested(*extra_contexts):
            store = CoreConfig().get_tracking_store()
            send_heartbeat_continuously(
                run_uid, store, heartbeat_interval, driver_pid
            )
예제 #6
0
    def task_run_execution_context(self, handle_sigterm=True, capture_log=True):
        """Yield once with every context manager needed to execute the task run.

        Entered (in this order): the task's RUN-phase task context, the task
        log capture when ``capture_log`` is true, the dbnd SIGTERM handler when
        ``handle_sigterm`` is true, and whatever managers the
        ``pm.hook.dbnd_task_run_context`` hook returns. The JVM context ids are
        set after all managers are created and before they are entered.
        """
        managers = [self.task.ctrl.task_context(phase=TaskContextPhase.RUN)]
        if capture_log:
            managers.append(self.task_run.log.capture_task_log())
        if handle_sigterm:
            managers.append(handle_sigterm_at_dbnd_task_run())
        managers += pm.hook.dbnd_task_run_context(task_run=self.task_run)

        jvm_ids = (
            self.run.run_uid,
            self.task_run_uid,
            self.task_run_attempt_uid,
            self.task_run.task_af_id,
        )
        set_current_jvm_context(*jvm_ids)

        with nested(*managers):
            yield
예제 #7
0
 def task_run_execution_context(self, handle_sigterm=True):
     """Yield once inside the contexts required to execute this task run.

     Always entered: the task's RUN-phase task context and the task log
     capture. Additionally, the dbnd SIGTERM handler when ``handle_sigterm``
     is true, followed by the managers returned by the
     ``pm.hook.dbnd_task_run_context`` hook.
     """
     managers = [
         self.task.ctrl.task_context(phase=TaskContextPhase.RUN),
         self.task_run.log.capture_task_log(),
     ] + ([handle_sigterm_at_dbnd_task_run()] if handle_sigterm else [])
     hook_managers = pm.hook.dbnd_task_run_context(task_run=self.task_run)
     managers.extend(hook_managers)
     with nested(*managers):
         yield
예제 #8
0
    def run(self):
        """Execute the root task runs of the current Databand run.

        * build the root task runs;
        * driver only: when ``system.describe`` is set, describe the DAG,
          log a "Described!" banner and stop; otherwise log the "created"
          banner, print a short tree view annotated with each task's
          ``is_reused`` flag, save the run if ``is_save_run`` allows it and
          start a heartbeat sender if ``sends_heartbeat`` is set;
        * hand every root task run to the task executor and run it inside
          ``nested(<heartbeat or None>)``;
        * log the finished (driver) or submitted (non-driver) banner.

        :return: the run for a driver that executed the tasks, else ``None``.
        """
        driver_task_run = current_task_run()
        run = driver_task_run.run  # type: DatabandRun
        root_task_runs = self.build_root_task_runs(run)

        heartbeat_ctx = None
        # describe is only handled by the local (driver) controller for now
        if self.is_driver:
            if run.context.settings.system.describe:
                run.describe_dag.describe_dag()
                described_banner = run.describe.run_banner("Described!",
                                                           color="blue")
                logger.info(described_banner)
                return None

            main_run = run.root_task_run
            run.root_task.ctrl.banner(
                "Main task '%s' has been created!" % main_run.task_af_id,
                color="cyan",
                task_run=main_run,
            )
            from dbnd._core.task_ctrl.task_dag_describe import DescribeDagCtrl

            reuse_by_task_id = {
                task_run.task.task_id: task_run.is_reused
                for task_run in root_task_runs
            }
            DescribeDagCtrl(
                main_run.task,
                DescribeFormat.short,
                complete_status=reuse_by_task_id,
            ).tree_view(describe_format=DescribeFormat.short)

            if self.is_save_run(run, root_task_runs):
                run.save_run()

            if self.sends_heartbeat:
                heartbeat_ctx = start_heartbeat_sender(driver_task_run)

        # the executor is created without the driver task
        executor = get_task_executor(
            run,
            task_executor_type=self.task_executor_type,
            host_engine=self.host_engine,
            target_engine=run.root_task_run.task_engine,
            task_runs=root_task_runs,
        )

        with nested(heartbeat_ctx):
            executor.do_run()

        if not self.is_driver:
            logger.info(run.describe.run_banner_for_submitted())
            return None

        # This is great success!
        run.set_run_state(RunState.SUCCESS)
        logger.info(run.describe.run_banner_for_finished())
        return run
예제 #9
0
    def initialize_band(self):
        """Call the task's ``band()`` and normalize any outputs it reassigned.

        Flow:
        1. When ``is_airflow_enabled()``, prepare the Databand operator-catcher
           DAG context to wrap the ``band()`` call.
        2. Snapshot every OUTPUTS parameter value (except ``task_band`` and
           ``FuncResultParameter`` ones) together with its value before
           ``band()`` runs.
        3. Run ``self.task.band()`` inside the prepared context(s) and store
           the result on the task and on this object.
        4. For ``PipelineTask`` instances, require every non-system output to
           be defined after ``band()``.
        5. Convert every output whose value object was replaced during
           ``band()`` with ``to_targets`` and store it back via
           ``update_param_value``.

        Any exception raised in steps 1-5 (including the friendly errors
        raised in steps 4 and 5) is logged as a red "Failed to run ..." banner
        at WARNING level. It is then either re-raised unchanged (when
        ``self.task.task_decorator`` is truthy) or wrapped with
        ``friendly_error.task_build.failed_to_call_band``.
        """
        try:
            # extra context managers to enter around the .band() call
            band_context = []
            if is_airflow_enabled():
                from dbnd_airflow.dbnd_task_executor.airflow_operators_catcher import (
                    get_databand_op_catcher_dag, )

                band_context.append(get_databand_op_catcher_dag())

            # snapshot (param_value, value before .band()) for the outputs
            # .band() may reassign; task_band itself and function-result
            # parameters are excluded
            original_param_values = []
            for param_value in self.task.task_params.get_param_values(
                    ParameterFilters.OUTPUTS):
                if param_value.name == "task_band" or isinstance(
                        param_value.parameter, FuncResultParameter):
                    continue
                original_param_values.append((param_value, param_value.value))

            with nested(*band_context):
                band = self.task.band()
                # this one would be normalized
                self.task._task_band_result = band
            self.task_band_result = band  # real value

            from dbnd import PipelineTask

            if isinstance(self.task, PipelineTask):
                # after .band has finished, all user outputs of the .band should be defined
                for param_value, _ in original_param_values:
                    # we want to validate only user facing parameters
                    # they should have assigned values by this moment,
                    # pipeline task can not have None outputs, after band call
                    if param_value.parameter.system:
                        continue
                    if is_not_defined(param_value.value):
                        raise friendly_error.task_build.pipeline_task_has_unassigned_outputs(
                            task=self.task, param=param_value.parameter)

            # now let's normalize if user has changed outputs
            for param_value, original_value in original_param_values:
                # identity check: only values replaced during .band() are
                # converted; untouched values are left as they were
                if param_value.value is original_value:
                    continue

                try:
                    from dbnd._core.utils.task_utils import to_targets

                    normalized_value = to_targets(param_value.value)
                    param_value.update_param_value(normalized_value)
                except Exception as ex:
                    raise friendly_error.task_build.failed_to_assign_param_value_at_band(
                        ex, param_value.parameter, param_value.value,
                        self.task)

        except Exception as ex:
            # NOTE(review): the friendly errors raised inside the try block
            # above (unassigned pipeline outputs, failed output assignment)
            # also land here and, for non-decorated tasks, get wrapped by
            # failed_to_call_band -- confirm that wrapping is intended.
            logger.warning(
                self.visualiser.banner(
                    msg="Failed to run %s" % _band_call_str(self.task),
                    color="red",
                    exc_info=sys.exc_info(),
                ))

            if self.task.task_decorator:
                # just re-raise, we already have an error from the "run" function
                raise

            raise friendly_error.task_build.failed_to_call_band(ex, self.task)
예제 #10
0
    def run_driver(self):
        """Build (if requested) and execute the root task of ``self.run``.

        * validate the git policy and let the remote engine prepare the run;
        * when ``root_task_name_to_build`` is set, build that task from the
          registry (renamed to ``force_task_name`` when given) and make it
          the run's root task;
        * check the task graph is a DAG (``topological_sort``) and create the
          task runs for the remote engine;
        * on ``run_config.dry``, describe the DAG and return without executing;
        * otherwise print the task tree, optionally save the run pickle and
          execute all non-skipped task runs inside
          ``nested(<heartbeat or None>)``;
        * copy the result to ``run.run_result_json_path`` when configured
          (copy failures are logged, not raised);
        * mark the run SUCCESS and log the "ready"/success banners.

        :return: the ``DatabandRun`` (also returned on a dry run).
        """
        logger.info("Running driver... Driver PID: %s", os.getpid())

        run = self.run  # type: DatabandRun
        settings = run.context.settings
        run_executor = run.run_executor
        remote_engine = run_executor.remote_engine

        settings.git.validate_git_policy()
        # let prepare for remote execution
        remote_engine.prepare_for_run(run)

        build_name = self.root_task_name_to_build
        if build_name:
            forced_name = self.force_task_name
            if forced_name:
                build_kwargs = {"task_name": forced_name}
                logger.info(
                    "Building main task '%s' with name %s",
                    build_name,
                    forced_name,
                )
            else:
                logger.info("Building main task '%s'", build_name)
                build_kwargs = {}
            built_task = get_task_registry().build_dbnd_task(
                build_name, task_kwargs=build_kwargs)
            logger.info(
                "Task %s has been created (%s children)",
                built_task.task_id,
                len(built_task.ctrl.task_dag.subdag_tasks()),
            )
            run.root_task = built_task

        # assert that graph is DAG
        run.root_task.task_dag.topological_sort()

        # now we init all task runs for all tasks in the pipeline
        all_task_runs = self._init_task_runs_for_execution(
            task_engine=remote_engine)
        main_run = run.root_task_run
        run.root_task.ctrl.banner(
            "Main task '%s' has been created!" % main_run.task_af_id,
            color="cyan",
            task_run=main_run,
        )

        if self.run_config.dry:
            run.root_task.ctrl.describe_dag.describe_dag()
            logger.warning(
                "Execution has been stopped due to run.dry=True flag!")
            return run

        print_tasks_tree(main_run.task, all_task_runs)
        if self._is_save_run_pickle(all_task_runs, remote_engine):
            run_executor.save_run_pickle()

        runnable = [
            task_run for task_run in all_task_runs if not task_run.is_skipped
        ]

        # THIS IS THE POINT WHEN WE SUBMIT ALL TASKS TO EXECUTION
        # the executor must be created without the driver task
        executor = get_task_executor(
            run,
            task_executor_type=run_executor.task_executor_type,
            host_engine=run_executor.host_engine,
            target_engine=remote_engine,
            task_runs=runnable,
        )

        # optionally wrap the executor with a "heartbeat" process
        heartbeat_ctx = start_heartbeat_sender(self) if self.send_heartbeat else None

        with nested(heartbeat_ctx):
            executor.do_run()

        # place the pipeline's task_band where the outside configuration asks for it
        result_json_path = settings.run.run_result_json_path
        if result_json_path:
            try:
                self.result_location.copy(result_json_path)
            except Exception as e:
                logger.exception(
                    "Couldn't copy the task_band from {old_path} to {new_path}. Failed with this error: {error}"
                    .format(old_path=self.result_location.path,
                            new_path=result_json_path,
                            error=e))
            else:
                logger.info(
                    "Copied the pipeline's task_band to {new_path}".format(
                        new_path=result_json_path))

        # if we are in the driver, we want to print banner after executor__task banner
        run.set_run_state(RunState.SUCCESS)

        final_root_task = self.run.root_task_run.task
        duration = self.run.duration
        if duration:
            summary = "Your run has been successfully executed in %s" % duration
        else:
            summary = "Your run has been successfully executed!"
        task_banner = final_root_task.ctrl.banner(
            "Main task '%s' is ready!" % final_root_task.task_name,
            color="green",
            task_run=self.run.root_task_run,
        )
        run_banner = run.describe.run_banner(summary, color="green",
                                             show_tasks_info=True)
        logger.info("\n%s\n%s\n" % (task_banner, run_banner))

        return run
예제 #11
0
    def run_driver(self):
        """Build (if requested) and execute the root task of ``self.run``.

        * validate the git policy and let the remote engine prepare the run;
        * when ``root_task_name_to_build`` is set, build that task from the
          registry and make it the run's root task;
        * check the task graph is a DAG (``topological_sort``) and create the
          task runs for the remote engine;
        * on ``run_config.dry``, describe the DAG and return without executing;
        * otherwise print the task tree, optionally save the run pickle and
          execute all non-skipped task runs inside
          ``nested(<heartbeat or None>)``;
        * mark the run SUCCESS and log the finished banner.

        :return: the ``DatabandRun`` (also returned on a dry run).
        """
        logger.info("Running driver... Driver PID: %s", os.getpid())

        run = self.run  # type: DatabandRun
        run_settings = run.context.settings
        run_executor = run.run_executor
        remote_engine = run_executor.remote_engine

        run_settings.git.validate_git_policy()
        # let prepare for remote execution
        remote_engine.prepare_for_run(run)

        build_name = self.root_task_name_to_build
        if build_name:
            logger.info("Building main task '%s'", build_name)
            built_task = get_task_registry().build_dbnd_task(build_name)
            logger.info(
                "Task %s has been created (%s children)",
                built_task.task_id,
                len(built_task.ctrl.task_dag.subdag_tasks()),
            )
            run.root_task = built_task

        # assert that graph is DAG
        run.root_task.task_dag.topological_sort()

        # now we init all task runs for all tasks in the pipeline
        all_task_runs = self._init_task_runs_for_execution(
            task_engine=remote_engine)
        main_run = run.root_task_run
        run.root_task.ctrl.banner(
            "Main task '%s' has been created!" % main_run.task_af_id,
            color="cyan",
            task_run=main_run,
        )

        if self.run_config.dry:
            run.root_task.ctrl.describe_dag.describe_dag()
            logger.warning(
                "Execution has been stopped due to run.dry=True flag!")
            return run

        print_tasks_tree(main_run.task, all_task_runs)
        if self._is_save_run_pickle(all_task_runs, remote_engine):
            run_executor.save_run_pickle()

        runnable = [
            task_run for task_run in all_task_runs if not task_run.is_skipped
        ]

        # THIS IS THE POINT WHEN WE SUBMIT ALL TASKS TO EXECUTION
        # the executor must be created without the driver task
        executor = get_task_executor(
            run,
            task_executor_type=run_executor.task_executor_type,
            host_engine=run_executor.host_engine,
            target_engine=remote_engine,
            task_runs=runnable,
        )

        # optionally wrap the executor with a "heartbeat" process
        heartbeat_ctx = start_heartbeat_sender(self) if self.send_heartbeat else None

        with nested(heartbeat_ctx):
            executor.do_run()

        # if we are in the driver, we want to print banner after executor__task banner
        run.set_run_state(RunState.SUCCESS)
        logger.info(run.describe.run_banner_for_finished())
        return run