コード例 #1
0
def build_dbnd_operator_from_taskrun(task_run):
    # type: (TaskRun)-> DbndOperator
    """Create the ``DbndOperator`` that represents ``task_run`` in an Airflow DAG.

    The task's parameter values are collected into a ``{name: value}`` dict,
    converted with ``convert_to_safe_types`` and passed as the operator's
    ``params``. Extra operator kwargs come from ``task.task_airflow_op_kwargs``.

    Retry settings are copied onto the operator when ``task.task_retries`` is
    set. The operator is stored on ``task.ctrl.airflow_op``, and its doc markdown
    is filled via ``set_af_operator_doc_md``.

    :param task_run: the dbnd task run to wrap.
    :return: the created operator.
    """
    dbnd_task = task_run.task

    raw_param_values = dict(
        (param_def.name, param_value)
        for param_def, param_value in dbnd_task._params.get_param_values()
    )
    safe_params = convert_to_safe_types(raw_param_values)
    extra_op_kwargs = dbnd_task.task_airflow_op_kwargs or {}

    operator = DbndOperator(
        task_id=task_run.task_af_id,
        dbnd_task_type=dbnd_task.get_task_family(),
        dbnd_task_id=dbnd_task.task_id,
        params=safe_params,
        **extra_op_kwargs
    )

    # only override the operator's retry policy when the task defines one
    if dbnd_task.task_retries is not None:
        operator.retries = dbnd_task.task_retries
        operator.retry_delay = dbnd_task.task_retry_delay

    dbnd_task.ctrl.airflow_op = operator
    set_af_operator_doc_md(task_run, operator)
    return operator
コード例 #2
0
    def _build_airflow_operator(self, task_cls, call_args, call_kwargs):
        """Instantiate a dbnd task and wrap it in a ``DbndFunctionalOperator``.

        First-level arguments that are Airflow operators or ``XComStr`` values are
        turned into ``xcom://`` targets, and the referenced task ids are recorded
        as upstream dependencies. Jinja arguments become ``jinja://`` targets.

        The task is then instantiated and an operator is created for it. Upstream
        operators are looked up in ``self.dag.task_dict`` and connected, each at
        most once. Finally, XCom handles for the task outputs are set as
        attributes on the operator.

        :param task_cls: dbnd task class to instantiate.
        :param call_args: positional arguments for ``task_cls``.
        :param call_kwargs: keyword arguments for ``task_cls``. Its ``task_name``
            entry (defaulting to the task family) is replaced in this dict by the
            normalized Airflow task id.
        :return: the ``XComStr`` of the single ``result`` output when the task has
            one plain result; otherwise ``XComResults`` holding ``(name, XComStr)``
            pairs for every output.
        :raises AttributeError: if ``task.task_airflow_op_kwargs`` contains a key
            that ``BaseOperator.__init__`` does not accept.
        """
        dag = self.dag
        # task ids of operators referenced through XCom in the call arguments
        upstream_task_ids = []

        # we support first level xcom values only (not nested)
        def _process_xcom_value(value):
            if isinstance(value, BaseOperator):
                value = build_xcom_str(value)
                upstream_task_ids.append(value.task_id)
            if isinstance(value, XComStr):
                # build_xcom_str presumably returns an XComStr, so an operator
                # argument can be recorded here a second time; duplicates are
                # skipped when the dependencies are wired below.
                upstream_task_ids.append(value.task_id)
                return target("xcom://%s" % value)
            if self._is_jinja_arg(value):
                return target("jinja://%s" % value)
            return value

        call_kwargs["task_name"] = af_task_id = self.get_normalized_airflow_task_id(
            call_kwargs.pop("task_name", task_cls.get_task_family())
        )
        call_args = [_process_xcom_value(arg) for arg in call_args]
        call_kwargs = {
            name: _process_xcom_value(arg) for name, arg in six.iteritems(call_kwargs)
        }

        task = task_cls(*call_args, **call_kwargs)  # type: Task

        # we will want to not cache "pipelines", as we need to run ".band()" per DAG
        setattr(task, "_dbnd_no_cache", True)

        user_inputs_only = task._params.get_param_values(
            user_only=True, input_only=True
        )
        # take only outputs that were set from the constructor: their
        # ParameterValue source in task.task_meta ends with "[ctor]"
        user_ctor_outputs_only = []
        for p_val in task.task_meta.task_params.values():
            if (
                p_val.parameter.is_output()
                and p_val.source
                and p_val.source.endswith("[ctor]")
            ):
                user_ctor_outputs_only.append((p_val.parameter, p_val.value))

        user_ctor_outputs_only_names = {
            p_val_def.name for p_val_def, _ in user_ctor_outputs_only
        }
        dbnd_xcom_inputs = []  # param names whose value is an XCom reference
        dbnd_task_params_fields = []  # every param name mirrored onto the operator
        non_templated_fields = []  # jinja params with templating disabled
        dbnd_task_params = {}
        for p_def, p_value in user_inputs_only + user_ctor_outputs_only:
            p_name = p_def.name

            dbnd_task_params_fields.append(p_name)
            if isinstance(p_value, FileTarget) and p_value.fs_name == "xcom":
                dbnd_xcom_inputs.append(p_name)
                p_value = p_value.path.replace("xcom://", "")
            elif isinstance(p_value, FileTarget) and p_value.fs_name == "jinja":
                p_value = p_value.path.replace("jinja://", "")
                if p_def.disable_jinja_templating:
                    non_templated_fields.append(p_name)
            dbnd_task_params[p_name] = convert_to_safe_types(p_value)

        single_value_result = False
        if task.task_definition.single_result_output:
            if isinstance(task.result, ResultProxyTarget):
                dbnd_xcom_outputs = task.result.names
            else:
                dbnd_xcom_outputs = ["result"]
                single_value_result = True
        else:
            dbnd_xcom_outputs = [
                p.name
                for p in task._params.get_params(output_only=True, user_only=True)
            ]
        # give the operator its own list object
        dbnd_xcom_outputs = list(dbnd_xcom_outputs)

        op_kwargs = task.task_airflow_op_kwargs or {}
        allowed_kwargs = signature(BaseOperator.__init__).parameters
        for kwarg in op_kwargs:
            if kwarg not in allowed_kwargs:
                raise AttributeError(
                    "__init__() got an unexpected keyword argument '{}'".format(kwarg)
                )

        op = DbndFunctionalOperator(
            task_id=af_task_id,
            dbnd_task_type=task.get_task_family(),
            dbnd_task_id=task.task_id,
            dbnd_xcom_inputs=dbnd_xcom_inputs,
            dbnd_xcom_outputs=dbnd_xcom_outputs,
            dbnd_overridden_output_params=user_ctor_outputs_only_names,
            dbnd_task_params_fields=dbnd_task_params_fields,
            params=dbnd_task_params,
            **op_kwargs
        )
        # doesn't work in airflow 1_10_0
        op.template_fields = [
            f for f in dbnd_task_params_fields if f not in non_templated_fields
        ]
        task.ctrl.airflow_op = op

        if task.task_retries is not None:
            op.retries = task.task_retries
            op.retry_delay = task.task_retry_delay
        # set_af_operator_doc_md(task_run, op)

        # task ids already handled, so the same upstream operator is registered
        # (or reported missing) at most once across both loops below
        connected_ids = set()
        for upstream_task in task.task_dag.upstream:
            upstream_operator = dag.task_dict.get(upstream_task.task_name)
            if not upstream_operator:
                self.__log_task_not_found_error(op, upstream_task.task_name)
                continue
            # only attach upstream operators that have no downstream tasks yet
            if not upstream_operator.downstream_task_ids:
                op.set_upstream(upstream_operator)
                connected_ids.add(upstream_operator.task_id)

        for task_id in upstream_task_ids:
            if task_id in connected_ids:
                continue
            connected_ids.add(task_id)
            upstream_task = dag.task_dict.get(task_id)
            if not upstream_task:
                self.__log_task_not_found_error(op, task_id)
                continue
            op.set_upstream(upstream_task)

        # populated Operator with current params values
        for k, v in six.iteritems(dbnd_task_params):
            setattr(op, k, v)

        results = [(n, build_xcom_str(op=op, name=n)) for n in dbnd_xcom_outputs]

        for n, xcom_arg in results:
            # outputs overridden from the constructor keep their param value
            if n in user_ctor_outputs_only_names:
                continue
            setattr(op, n, xcom_arg)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                task.ctrl.banner("Created task '%s'." % task.task_name, color="green")
            )
            logger.debug(
                "%s\n\tparams: %s\n\toutputs: %s",
                task.task_id,
                dbnd_task_params,
                results,
            )
        if single_value_result:
            result = results[0]
            return result[1]  # return result XComStr
        return XComResults(result=build_xcom_str(op), sub_results=results)
コード例 #3
0
    def execute(self, context):
        """Run the wrapped dbnd task from inside an Airflow worker.

        The param fields listed in ``self.dbnd_task_params_fields`` are read back
        from this operator; names in ``self.dbnd_xcom_inputs`` are wrapped with
        ``target``. The cached dbnd task is then cloned with those values, hung
        under a ``Task`` named after the DAG, and executed inside a new databand
        run.

        :param context: Airflow task context; ``dag`` and ``execution_date`` are read.
        :return: dict mapping each name in ``self.dbnd_xcom_outputs`` to the
            ``convert_to_safe_types`` value of that attribute on the cloned task.
        """
        logger.debug("Running dbnd dbnd_task from airflow operator %s", self.task_id)

        airflow_dag = context["dag"]
        target_date = context["execution_date"]
        run_uid = get_job_run_uid(dag_id=airflow_dag.dag_id, execution_date=target_date)

        # Airflow has already written XCom/templated values onto this operator's
        # fields; collect them as constructor overrides for the real dbnd task.
        clone_kwargs = {}
        for field_name in self.dbnd_task_params_fields:
            field_value = getattr(self, field_name, None)
            if field_name in self.dbnd_xcom_inputs:
                # the rendered value is the location of the real input
                field_value = target(field_value)
            clone_kwargs[field_name] = field_value
        clone_kwargs["_dbnd_disable_airflow_inplace"] = True

        dbnd_dag_ctrl = self.get_dbnd_dag_ctrl()
        with DatabandContext.context(_context=dbnd_dag_ctrl.dbnd_context) as dc:
            logger.debug("Running %s with kwargs=%s ", self.task_id, clone_kwargs)
            cached_task = dc.task_instance_cache.get_task_by_id(self.dbnd_task_id)
            # rebuild the task with the values that came xcom -> operator
            with cached_task.ctrl.task_context(phase=TaskContextPhase.BUILD):
                dbnd_task = cached_task.clone(
                    output_params_to_clone=self.dbnd_overridden_output_params,
                    **clone_kwargs
                )

            logger.debug("Creating inplace databand run for driver dump")
            root_task = Task(task_name=airflow_dag.dag_id, task_target_date=target_date)
            root_task.set_upstream(dbnd_task)

            with new_databand_run(
                context=dc,
                task_or_task_name=root_task,
                run_uid=run_uid,
                existing_run=False,
                job_name=airflow_dag.dag_id,
            ) as dr:  # type: DatabandRun
                dr._init_without_run()
                task_run = dr.get_task_run_by_id(dbnd_task.task_id)

                if dbnd_task._conf__require_run_dump_file:
                    dr.save_run()

                logger.info(
                    dbnd_task.ctrl.banner(
                        "Running task '%s'." % dbnd_task.task_name, color="cyan"
                    )
                )
                # should be replaced with  tr._execute call
                task_run.runner.execute(airflow_context=context, handle_sigterm=False)

            logger.debug("Finished to run %s", self)
            xcom_values = dict(
                (output_name, convert_to_safe_types(getattr(dbnd_task, output_name)))
                for output_name in self.dbnd_xcom_outputs
            )
        return xcom_values
コード例 #4
0
    def _build_airflow_operator(self, task_cls, call_args, call_kwargs):
        """Instantiate a dbnd task and wrap it in a ``DbndFunctionalOperator``.

        First-level arguments that are Airflow operators or ``XComStr`` values are
        turned into ``xcom://`` targets, and the referenced task ids are recorded
        as upstream dependencies.

        The task is then instantiated and an operator is created for it. Operators
        of the task's children and of the XCom-referenced tasks are looked up in
        ``self.dag.task_dict`` and connected upstream, each at most once. Finally,
        XCom handles for the task outputs are set as attributes on the operator.

        :param task_cls: dbnd task class to instantiate.
        :param call_args: positional arguments for ``task_cls``.
        :param call_kwargs: keyword arguments for ``task_cls``. Its ``task_name``
            entry (defaulting to the task family) is replaced in this dict by the
            normalized Airflow task id.
        :return: the ``XComStr`` of the single ``result`` output when the task has
            one plain result; otherwise ``XComResults`` holding ``(name, XComStr)``
            pairs for every output.
        """
        dag = self.dag
        # task ids of operators referenced through XCom in the call arguments
        upstream_task_ids = []

        # we support first level xcom values only (not nested)
        def _process_xcom_value(value):
            if isinstance(value, BaseOperator):
                value = build_xcom_str(value)
                upstream_task_ids.append(value.task_id)
            if isinstance(value, XComStr):
                # build_xcom_str presumably returns an XComStr, so an operator
                # argument can be recorded here a second time; duplicates are
                # skipped when the dependencies are wired below.
                upstream_task_ids.append(value.task_id)
                return target("xcom://%s" % value)
            return value

        call_kwargs["task_name"] = af_task_id = self.get_normalized_airflow_task_id(
            call_kwargs.pop("task_name", task_cls.get_task_family())
        )
        call_args = [_process_xcom_value(arg) for arg in call_args]
        call_kwargs = {
            name: _process_xcom_value(arg) for name, arg in six.iteritems(call_kwargs)
        }

        task = task_cls(*call_args, **call_kwargs)  # type: Task

        # we will want to not cache "pipelines", as we need to run ".band()" per DAG
        setattr(task, "_dbnd_no_cache", True)

        user_inputs_only = task._params.get_param_values(
            user_only=True, input_only=True
        )
        dbnd_xcom_inputs = []  # param names whose value is an XCom reference
        dbnd_task_params_fields = []  # every param name mirrored onto the operator
        dbnd_task_params = {}
        for p_def, p_value in user_inputs_only:
            p_name = p_def.name

            dbnd_task_params_fields.append(p_name)
            if isinstance(p_value, FileTarget) and p_value.fs_name == "xcom":
                dbnd_xcom_inputs.append(p_name)
                p_value = p_value.path.replace("xcom://", "")
            dbnd_task_params[p_name] = convert_to_safe_types(p_value)

        single_value_result = False
        if task.task_definition.single_result_output:
            if isinstance(task.result, ResultProxyTarget):
                dbnd_xcom_outputs = task.result.names
            else:
                dbnd_xcom_outputs = ["result"]
                single_value_result = True
        else:
            dbnd_xcom_outputs = [
                p.name
                for p in task._params.get_params(output_only=True, user_only=True)
            ]

        op_kwargs = task.task_airflow_op_kwargs or {}

        op = DbndFunctionalOperator(
            task_id=af_task_id,
            dbnd_task_type=task.get_task_family(),
            dbnd_task_id=task.task_id,
            dbnd_xcom_inputs=dbnd_xcom_inputs,
            dbnd_xcom_outputs=dbnd_xcom_outputs,
            dbnd_task_params_fields=dbnd_task_params_fields,
            params=dbnd_task_params,
            **op_kwargs
        )
        # doesn't work in airflow 1_10_0
        # Use a copy so template_fields does not alias the list passed to the
        # operator as dbnd_task_params_fields.
        op.template_fields = list(dbnd_task_params_fields)
        task.ctrl.airflow_op = op

        if task.task_retries is not None:
            op.retries = task.task_retries
            op.retry_delay = task.task_retry_delay
        # set_af_operator_doc_md(task_run, op)

        # task ids already handled, so the same upstream operator is registered
        # (or reported missing) at most once across both loops below
        connected_ids = set()

        # in case we are inside pipeline, pipeline task will create more operators inside itself.
        for t_child in task.task_meta.children:
            # let's reconnect to all internal tasks
            t_child = self.dbnd_context.task_instance_cache.get_task_by_id(t_child)
            if t_child.task_name in connected_ids:
                continue
            connected_ids.add(t_child.task_name)
            upstream_task = dag.task_dict.get(t_child.task_name)
            if not upstream_task:
                self.__log_task_not_found_error(op, t_child.task_name)
                continue
            op.set_upstream(upstream_task)

        for task_id in upstream_task_ids:
            if task_id in connected_ids:
                continue
            connected_ids.add(task_id)
            upstream_task = dag.task_dict.get(task_id)
            if not upstream_task:
                self.__log_task_not_found_error(op, task_id)
                continue
            op.set_upstream(upstream_task)

        # populated Operator with current params values
        for k, v in six.iteritems(dbnd_task_params):
            setattr(op, k, v)

        results = [(n, build_xcom_str(op=op, name=n)) for n in dbnd_xcom_outputs]
        for n, xcom_arg in results:
            setattr(op, n, xcom_arg)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                task.ctrl.banner("Created task '%s'." % task.task_name, color="green")
            )
            logger.debug(
                "%s\n\tparams: %s\n\toutputs: %s",
                task.task_id,
                dbnd_task_params,
                results,
            )
        if single_value_result:
            result = results[0]
            return result[1]  # return result XComStr
        return XComResults(result=build_xcom_str(op), sub_results=results)