def build_dbnd_operator_from_taskrun(task_run):
    # type: (TaskRun) -> DbndOperator
    """Build an Airflow ``DbndOperator`` for the given databand task run.

    The task's current parameter values are converted to Airflow-safe types
    and passed as the operator ``params``; retry settings are propagated when
    present, the operator is linked back onto the task ctrl, and
    ``set_af_operator_doc_md`` is invoked before returning the operator.
    """
    dbnd_task = task_run.task

    # Snapshot all current parameter values, converted to types Airflow can
    # safely serialize.
    raw_param_values = {
        p.name: value for p, value in dbnd_task._params.get_param_values()
    }
    safe_params = convert_to_safe_types(raw_param_values)

    extra_op_kwargs = dbnd_task.task_airflow_op_kwargs or {}
    operator = DbndOperator(
        task_id=task_run.task_af_id,
        dbnd_task_type=dbnd_task.get_task_family(),
        dbnd_task_id=dbnd_task.task_id,
        params=safe_params,
        **extra_op_kwargs
    )

    # Only override Airflow's retry defaults when the dbnd task defines them.
    if dbnd_task.task_retries is not None:
        operator.retries = dbnd_task.task_retries
        operator.retry_delay = dbnd_task.task_retry_delay

    dbnd_task.ctrl.airflow_op = operator
    set_af_operator_doc_md(task_run, operator)
    return operator
def _build_airflow_operator(self, task_cls, call_args, call_kwargs):
    """Instantiate a dbnd task and wrap it in a ``DbndFunctionalOperator``.

    Resolves XCom/jinja references in the call arguments, builds the task,
    derives its input/output parameter wiring, creates the operator, connects
    upstream relationships in the DAG, and returns either a single XComStr
    (single-result tasks) or an ``XComResults`` aggregate.
    """
    dag = self.dag
    upstream_task_ids = []

    # we support first level xcom values only (not nested)
    def _process_xcom_value(value):
        # A raw operator argument is converted to its XCom string form first,
        # then handled by the XComStr branch below.
        # NOTE(review): when value is a BaseOperator, task_id appears to be
        # appended twice to upstream_task_ids (once here, once in the XComStr
        # branch) — confirm whether the duplicate set_upstream later is benign.
        if isinstance(value, BaseOperator):
            value = build_xcom_str(value)
            upstream_task_ids.append(value.task_id)
        if isinstance(value, XComStr):
            upstream_task_ids.append(value.task_id)
            # Encode the XCom reference as a pseudo-target URI; decoded later
            # in the parameter loop below.
            return target("xcom://%s" % value)
        if self._is_jinja_arg(value):
            return target("jinja://%s" % value)
        return value

    # Normalize the task name into a DAG-unique Airflow task id; the same
    # string is used as the dbnd task_name.
    call_kwargs["task_name"] = af_task_id = self.get_normalized_airflow_task_id(
        call_kwargs.pop("task_name", task_cls.get_task_family())
    )
    call_args = [_process_xcom_value(arg) for arg in call_args]
    call_kwargs = {
        name: _process_xcom_value(arg) for name, arg in six.iteritems(call_kwargs)
    }

    task = task_cls(*call_args, **call_kwargs)  # type: Task
    # we will want to not cache "pipelines", as we need to run ".band()" per DAG
    setattr(task, "_dbnd_no_cache", True)

    user_inputs_only = task._params.get_param_values(
        user_only=True, input_only=True
    )

    # take only outputs that are coming from ctror ( based on ParameterValue in task.task_meta
    user_ctor_outputs_only = []
    for p_val in task.task_meta.task_params.values():
        if (
            p_val.parameter.is_output()
            and p_val.source
            and p_val.source.endswith("[ctor]")
        ):
            user_ctor_outputs_only.append((p_val.parameter, p_val.value))
    # Names of ctor-provided outputs; used later to skip re-assigning those
    # attributes on the operator and passed as dbnd_overridden_output_params.
    user_ctor_outputs_only_names = {
        p_val_def.name for p_val_def, p_val in user_ctor_outputs_only
    }

    dbnd_xcom_inputs = []          # params whose value comes from an upstream XCom
    dbnd_task_params_fields = []   # all param names exposed on the operator
    non_templated_fields = []      # jinja params excluded from template_fields
    dbnd_task_params = {}
    for p_def, p_value in user_inputs_only + user_ctor_outputs_only:
        p_name = p_def.name
        dbnd_task_params_fields.append(p_name)
        if isinstance(p_value, FileTarget) and p_value.fs_name == "xcom":
            # Unwrap the pseudo-target created in _process_xcom_value; the bare
            # XCom string is left for Airflow templating to resolve at runtime.
            dbnd_xcom_inputs.append(p_name)
            p_value = p_value.path.replace("xcom://", "")
        elif isinstance(p_value, FileTarget) and p_value.fs_name == "jinja":
            p_value = p_value.path.replace("jinja://", "")
            if p_def.disable_jinja_templating:
                non_templated_fields.append(p_name)
        dbnd_task_params[p_name] = convert_to_safe_types(p_value)

    # Determine which names the operator publishes to XCom.
    single_value_result = False
    if task.task_definition.single_result_output:
        if isinstance(task.result, ResultProxyTarget):
            # Multi-named result proxy: publish each sub-result by name.
            dbnd_xcom_outputs = task.result.names
        else:
            dbnd_xcom_outputs = ["result"]
            single_value_result = True
    else:
        dbnd_xcom_outputs = [
            p.name
            for p in task._params.get_params(output_only=True, user_only=True)
        ]
    # NOTE(review): this copy looks like a no-op ([n for n in ...]) — possibly
    # meant to materialize a lazy/shared sequence; confirm before removing.
    dbnd_xcom_outputs = [n for n in dbnd_xcom_outputs]

    op_kwargs = task.task_airflow_op_kwargs or {}
    # Fail fast (mirroring Python's own TypeError text) if the user passed an
    # op kwarg that BaseOperator.__init__ does not accept.
    allowed_kwargs = signature(BaseOperator.__init__).parameters
    for kwarg in op_kwargs:
        if kwarg not in allowed_kwargs:
            raise AttributeError(
                "__init__() got an unexpected keyword argument '{}'".format(kwarg)
            )

    op = DbndFunctionalOperator(
        task_id=af_task_id,
        dbnd_task_type=task.get_task_family(),
        dbnd_task_id=task.task_id,
        dbnd_xcom_inputs=dbnd_xcom_inputs,
        dbnd_xcom_outputs=dbnd_xcom_outputs,
        dbnd_overridden_output_params=user_ctor_outputs_only_names,
        dbnd_task_params_fields=dbnd_task_params_fields,
        params=dbnd_task_params,
        **op_kwargs
    )
    # doesn't work in airflow 1_10_0
    # Every dbnd param field is jinja-templated except those explicitly opted out.
    op.template_fields = [
        f for f in dbnd_task_params_fields if f not in non_templated_fields
    ]
    task.ctrl.airflow_op = op
    if task.task_retries is not None:
        op.retries = task.task_retries
        op.retry_delay = task.task_retry_delay
    # set_af_operator_doc_md(task_run, op)

    # Wire DAG-level upstream dependencies discovered through the dbnd task
    # graph; only connect operators that do not already have downstream links.
    for upstream_task in task.task_dag.upstream:
        upstream_operator = dag.task_dict.get(upstream_task.task_name)
        if not upstream_operator:
            self.__log_task_not_found_error(op, upstream_task.task_name)
            continue
        if not upstream_operator.downstream_task_ids:
            op.set_upstream(upstream_operator)

    # Wire explicit XCom-based dependencies collected in _process_xcom_value.
    for task_id in upstream_task_ids:
        upstream_task = dag.task_dict.get(task_id)
        if not upstream_task:
            self.__log_task_not_found_error(op, task_id)
            continue
        op.set_upstream(upstream_task)

    # populated Operator with current params values
    for k, v in six.iteritems(dbnd_task_params):
        setattr(op, k, v)

    # Expose each output as an XComStr attribute on the operator so downstream
    # tasks can reference it; ctor-provided outputs keep their original value.
    results = [(n, build_xcom_str(op=op, name=n)) for n in dbnd_xcom_outputs]
    for n, xcom_arg in results:
        if n in user_ctor_outputs_only_names:
            continue
        setattr(op, n, xcom_arg)

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            task.ctrl.banner("Created task '%s'." % task.task_name, color="green")
        )
        logger.debug(
            "%s\n\tparams: %s\n\toutputs: %s",
            task.task_id,
            dbnd_task_params,
            results,
        )

    if single_value_result:
        result = results[0]
        return result[1]  # return result XComStr
    return XComResults(result=build_xcom_str(op), sub_results=results)
def execute(self, context):
    """Airflow entry point: rebuild the dbnd task from templated operator
    fields and execute it inside a fresh databand run.

    Returns a dict of ``{output_name: safe_value}`` for each declared XCom
    output, which Airflow pushes to XCom for downstream operators.
    """
    logger.debug("Running dbnd dbnd_task from airflow operator %s", self.task_id)

    dag = context["dag"]
    execution_date = context["execution_date"]
    dag_id = dag.dag_id
    # Deterministic run uid derived from (dag_id, execution_date) — presumably
    # so retries of the same logical run map to the same uid; confirm in
    # get_job_run_uid.
    run_uid = get_job_run_uid(dag_id=dag_id, execution_date=execution_date)

    # Airflow has updated all relevant fields in Operator definition with XCom values
    # now we can create a real dbnd dbnd_task with real references to dbnd_task
    new_kwargs = {}
    for p_name in self.dbnd_task_params_fields:
        new_kwargs[p_name] = getattr(self, p_name, None)
        # this is the real input value after
        # XCom-fed params arrive as plain paths/strings; re-wrap them as targets.
        if p_name in self.dbnd_xcom_inputs:
            new_kwargs[p_name] = target(new_kwargs[p_name])

    # Guard against the rebuilt task trying to register itself into Airflow again.
    new_kwargs["_dbnd_disable_airflow_inplace"] = True

    dag_ctrl = self.get_dbnd_dag_ctrl()
    with DatabandContext.context(_context=dag_ctrl.dbnd_context) as dc:
        logger.debug("Running %s with kwargs=%s ", self.task_id, new_kwargs)
        dbnd_task = dc.task_instance_cache.get_task_by_id(self.dbnd_task_id)

        # rebuild task with new values coming from xcom->operator
        with dbnd_task.ctrl.task_context(phase=TaskContextPhase.BUILD):
            dbnd_task = dbnd_task.clone(
                output_params_to_clone=self.dbnd_overridden_output_params,
                **new_kwargs
            )

        logger.debug("Creating inplace databand run for driver dump")
        # Synthetic root task named after the DAG, so the databand run has a
        # single entry point wrapping the real task.
        dag_task = Task(task_name=dag.dag_id, task_target_date=execution_date)
        dag_task.set_upstream(dbnd_task)

        # create databand run
        with new_databand_run(
            context=dc,
            task_or_task_name=dag_task,
            run_uid=run_uid,
            existing_run=False,
            job_name=dag.dag_id,
        ) as dr:  # type: DatabandRun
            # Build run structures without executing the whole driver.
            dr._init_without_run()
            # dr.driver_task_run.set_task_run_state(state=TaskRunState.RUNNING)
            # "make dag run"
            # dr.root_task_run.set_task_run_state(state=TaskRunState.RUNNING)
            dbnd_task_run = dr.get_task_run_by_id(dbnd_task.task_id)

            # Persist the run to disk only when the task's config requires a
            # run dump file.
            needs_databand_run_save = dbnd_task._conf__require_run_dump_file
            if needs_databand_run_save:
                dr.save_run()

            logger.info(
                dbnd_task.ctrl.banner(
                    "Running task '%s'." % dbnd_task.task_name, color="cyan"
                )
            )
            # should be replaced with tr._execute call
            # handle_sigterm=False: Airflow owns signal handling here.
            dbnd_task_run.runner.execute(
                airflow_context=context, handle_sigterm=False
            )
        logger.debug("Finished to run %s", self)

        # Collect declared outputs (converted to XCom-safe types) for Airflow
        # to push downstream.
        result = {
            output_name: convert_to_safe_types(getattr(dbnd_task, output_name))
            for output_name in self.dbnd_xcom_outputs
        }
        return result
def _build_airflow_operator(self, task_cls, call_args, call_kwargs):
    """Instantiate a dbnd task and wrap it in a ``DbndFunctionalOperator``.

    Resolves first-level XCom references in the call arguments, builds the
    task, derives its input/output wiring, connects upstream operators in the
    DAG (including operators created by child tasks of a pipeline), and
    returns a single XComStr or an ``XComResults`` aggregate.
    """
    dag = self.dag
    upstream_task_ids = []

    # we support first level xcom values only (not nested)
    def _process_xcom_value(value):
        # A raw operator argument is converted to its XCom string form first,
        # then handled by the XComStr branch below.
        # NOTE(review): for a BaseOperator value, task_id appears to be
        # appended to upstream_task_ids twice (both branches run) — confirm
        # the duplicate set_upstream later is benign.
        if isinstance(value, BaseOperator):
            value = build_xcom_str(value)
            upstream_task_ids.append(value.task_id)
        if isinstance(value, XComStr):
            upstream_task_ids.append(value.task_id)
            # Encode the XCom reference as a pseudo-target URI; decoded in the
            # parameter loop below.
            return target("xcom://%s" % value)
        return value

    # Normalize the task name into a DAG-unique Airflow task id; the same
    # string is used as the dbnd task_name.
    call_kwargs[
        "task_name"] = af_task_id = self.get_normalized_airflow_task_id(
        call_kwargs.pop("task_name", task_cls.get_task_family()))
    call_args = [_process_xcom_value(arg) for arg in call_args]
    call_kwargs = {
        name: _process_xcom_value(arg) for name, arg in six.iteritems(call_kwargs)
    }

    task = task_cls(*call_args, **call_kwargs)  # type: Task
    # we will want to not cache "pipelines", as we need to run ".band()" per DAG
    setattr(task, "_dbnd_no_cache", True)

    user_inputs_only = task._params.get_param_values(user_only=True,
                                                     input_only=True)

    dbnd_xcom_inputs = []         # params whose value comes from an upstream XCom
    dbnd_task_params_fields = []  # all param names exposed on the operator
    dbnd_task_params = {}
    for p_def, p_value in user_inputs_only:
        p_name = p_def.name
        dbnd_task_params_fields.append(p_name)
        if isinstance(p_value, FileTarget) and p_value.fs_name == "xcom":
            # Unwrap the pseudo-target; the bare XCom string is left for
            # Airflow templating to resolve at runtime.
            dbnd_xcom_inputs.append(p_name)
            p_value = p_value.path.replace("xcom://", "")
        dbnd_task_params[p_name] = convert_to_safe_types(p_value)

    # Determine which names the operator publishes to XCom.
    single_value_result = False
    if task.task_definition.single_result_output:
        if isinstance(task.result, ResultProxyTarget):
            # Multi-named result proxy: publish each sub-result by name.
            dbnd_xcom_outputs = task.result.names
        else:
            dbnd_xcom_outputs = ["result"]
            single_value_result = True
    else:
        dbnd_xcom_outputs = [
            p.name
            for p in task._params.get_params(output_only=True, user_only=True)
        ]

    op_kwargs = task.task_airflow_op_kwargs or {}
    op = DbndFunctionalOperator(
        task_id=af_task_id,
        dbnd_task_type=task.get_task_family(),
        dbnd_task_id=task.task_id,
        dbnd_xcom_inputs=dbnd_xcom_inputs,
        dbnd_xcom_outputs=dbnd_xcom_outputs,
        dbnd_task_params_fields=dbnd_task_params_fields,
        params=dbnd_task_params,
        **op_kwargs)

    # doesn't work in airflow 1_10_0
    op.template_fields = dbnd_task_params_fields

    task.ctrl.airflow_op = op
    if task.task_retries is not None:
        op.retries = task.task_retries
        op.retry_delay = task.task_retry_delay
    # set_af_operator_doc_md(task_run, op)

    # in case we are inside pipeline, pipeline task will create more operators inside itself.
    for t_child in task.task_meta.children:
        # let's reconnect to all internal tasks
        t_child = self.dbnd_context.task_instance_cache.get_task_by_id(
            t_child)
        upstream_task = dag.task_dict.get(t_child.task_name)
        if not upstream_task:
            self.__log_task_not_found_error(op, t_child.task_name)
            continue
        op.set_upstream(upstream_task)

    # Wire explicit XCom-based dependencies collected in _process_xcom_value.
    for task_id in upstream_task_ids:
        upstream_task = dag.task_dict.get(task_id)
        if not upstream_task:
            self.__log_task_not_found_error(op, task_id)
            continue
        op.set_upstream(upstream_task)

    # populated Operator with current params values
    for k, v in six.iteritems(dbnd_task_params):
        setattr(op, k, v)

    # Expose each output as an XComStr attribute so downstream tasks can
    # reference it.
    results = [(n, build_xcom_str(op=op, name=n)) for n in dbnd_xcom_outputs]
    for n, xcom_arg in results:
        setattr(op, n, xcom_arg)

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            task.ctrl.banner("Created task '%s'." % task.task_name,
                             color="green"))
        logger.debug(
            "%s\n\tparams: %s\n\toutputs: %s",
            task.task_id,
            dbnd_task_params,
            results,
        )

    if single_value_result:
        result = results[0]
        return result[1]  # return result XComStr
    return XComResults(result=build_xcom_str(op), sub_results=results)