def _process_xcom_value(value):
    """Translate one call argument into the form the task constructor expects.

    * A ``BaseOperator`` is first converted with ``build_xcom_str``.
    * An ``XComStr`` is returned as ``target("xcom://<value>")``.
    * A value accepted by ``self._is_jinja_arg`` is returned as
      ``target("jinja://<value>")``.
    * Anything else is returned unchanged.

    The ``task_id`` of every operator / ``XComStr`` seen is recorded in
    ``upstream_task_ids``, at most once per id.  Previously an operator
    argument was recorded twice (once in the operator branch and again when
    it fell through to the ``XComStr`` branch).

    NOTE(review): ``self`` and ``upstream_task_ids`` are not parameters; this
    helper expects them to come from an enclosing scope that is not part of
    this block -- confirm against the caller.
    """
    if isinstance(value, BaseOperator):
        value = build_xcom_str(value)
        if value.task_id not in upstream_task_ids:
            upstream_task_ids.append(value.task_id)

    if isinstance(value, XComStr):
        # The membership test keeps an id recorded by the branch above (or by
        # an earlier argument) from being registered a second time.
        if value.task_id not in upstream_task_ids:
            upstream_task_ids.append(value.task_id)
        return target("xcom://%s" % value)

    if self._is_jinja_arg(value):
        return target("jinja://%s" % value)

    return value
def _build_airflow_operator(self, task_cls, call_args, call_kwargs):
    """Instantiate ``task_cls`` and build the ``DbndFunctionalOperator`` for it.

    Call arguments that are Airflow operators or ``XComStr`` values are
    replaced by ``xcom://`` targets and their task ids are remembered as
    upstream dependencies; arguments accepted by ``self._is_jinja_arg`` are
    replaced by ``jinja://`` targets.  Only top-level argument values are
    inspected (nested containers are not).

    The created operator is stored on ``task.ctrl.airflow_op``, linked to the
    upstream operators found in ``self.dag.task_dict``, and populated with the
    task parameter values and with one xcom string attribute per output.

    Args:
        task_cls: task class to instantiate; must expose ``get_task_family()``.
        call_args: positional arguments for ``task_cls``.
        call_kwargs: keyword arguments for ``task_cls``.  NOTE: mutated in
            place -- ``"task_name"`` is popped and re-inserted with the
            normalized Airflow task id.

    Returns:
        The xcom string of the ``"result"`` output when the task definition
        has a single result output that is not a ``ResultProxyTarget``;
        otherwise an ``XComResults`` holding the operator-level xcom string
        and the ``(name, xcom string)`` pairs of all outputs.

    Raises:
        AttributeError: if ``task.task_airflow_op_kwargs`` contains a key that
            is not a parameter of ``BaseOperator.__init__``.
    """
    dag = self.dag
    upstream_task_ids = []

    def _register_upstream(task_id):
        # Each id is later turned into one ``op.set_upstream`` call, so keep
        # the list duplicate-free (an operator argument used to be recorded
        # twice, and so was the same upstream passed through two arguments).
        if task_id not in upstream_task_ids:
            upstream_task_ids.append(task_id)

    # we support first level xcom values only (not nested)
    def _process_xcom_value(value):
        if isinstance(value, BaseOperator):
            value = build_xcom_str(value)
            _register_upstream(value.task_id)
        if isinstance(value, XComStr):
            _register_upstream(value.task_id)
            return target("xcom://%s" % value)
        if self._is_jinja_arg(value):
            return target("jinja://%s" % value)
        return value

    # The task name defaults to the task family and is normalized before it
    # is used both as the task's name and as the Airflow task id.
    af_task_id = self.get_normalized_airflow_task_id(
        call_kwargs.pop("task_name", task_cls.get_task_family())
    )
    call_kwargs["task_name"] = af_task_id

    call_args = [_process_xcom_value(arg) for arg in call_args]
    call_kwargs = {
        name: _process_xcom_value(arg) for name, arg in six.iteritems(call_kwargs)
    }
    task = task_cls(*call_args, **call_kwargs)  # type: Task

    # we will want to not cache "pipelines", as we need to run ".band()" per DAG
    task._dbnd_no_cache = True

    user_inputs_only = task._params.get_param_values(
        user_only=True, input_only=True
    )
    # take only outputs that are coming from the ctor
    # (based on the ParameterValue entries in task.task_meta)
    user_ctor_outputs_only = [
        (p_val.parameter, p_val.value)
        for p_val in task.task_meta.task_params.values()
        if p_val.parameter.is_output()
        and p_val.source
        and p_val.source.endswith("[ctor]")
    ]
    user_ctor_outputs_only_names = {
        p_def.name for p_def, _ in user_ctor_outputs_only
    }

    dbnd_xcom_inputs = []
    dbnd_task_params_fields = []
    non_templated_fields = []
    dbnd_task_params = {}
    for p_def, p_value in user_inputs_only + user_ctor_outputs_only:
        p_name = p_def.name
        dbnd_task_params_fields.append(p_name)

        # Strip the marker scheme added by _process_xcom_value so the
        # operator receives the raw xcom / jinja string again.
        if isinstance(p_value, FileTarget) and p_value.fs_name == "xcom":
            dbnd_xcom_inputs.append(p_name)
            p_value = p_value.path.replace("xcom://", "")
        elif isinstance(p_value, FileTarget) and p_value.fs_name == "jinja":
            p_value = p_value.path.replace("jinja://", "")
            # NOTE(review): this check is assumed to belong to the jinja
            # branch only -- confirm against the upstream formatting.
            if p_def.disable_jinja_templating:
                non_templated_fields.append(p_name)

        dbnd_task_params[p_name] = convert_to_safe_types(p_value)

    single_value_result = False
    if task.task_definition.single_result_output:
        if isinstance(task.result, ResultProxyTarget):
            dbnd_xcom_outputs = task.result.names
        else:
            dbnd_xcom_outputs = ["result"]
            single_value_result = True
    else:
        dbnd_xcom_outputs = [
            p.name for p in task._params.get_params(output_only=True, user_only=True)
        ]
    # Work on our own list so the operator never shares task.result.names.
    dbnd_xcom_outputs = list(dbnd_xcom_outputs)

    # Reject unknown operator kwargs up front, with the same message shape a
    # failing BaseOperator.__init__ call would produce.
    op_kwargs = task.task_airflow_op_kwargs or {}
    allowed_kwargs = signature(BaseOperator.__init__).parameters
    for kwarg in op_kwargs:
        if kwarg not in allowed_kwargs:
            raise AttributeError(
                "__init__() got an unexpected keyword argument '{}'".format(kwarg)
            )

    op = DbndFunctionalOperator(
        task_id=af_task_id,
        dbnd_task_type=task.get_task_family(),
        dbnd_task_id=task.task_id,
        dbnd_xcom_inputs=dbnd_xcom_inputs,
        dbnd_xcom_outputs=dbnd_xcom_outputs,
        dbnd_overridden_output_params=user_ctor_outputs_only_names,
        dbnd_task_params_fields=dbnd_task_params_fields,
        params=dbnd_task_params,
        **op_kwargs
    )
    # doesn't work in airflow 1_10_0
    op.template_fields = [
        f for f in dbnd_task_params_fields if f not in non_templated_fields
    ]
    task.ctrl.airflow_op = op

    if task.task_retries is not None:
        op.retries = task.task_retries
        op.retry_delay = task.task_retry_delay
    # set_af_operator_doc_md(task_run, op)

    # Dependencies declared on the task itself (looked up by task name).
    for upstream_task in task.task_dag.upstream:
        upstream_operator = dag.task_dict.get(upstream_task.task_name)
        if not upstream_operator:
            self.__log_task_not_found_error(op, upstream_task.task_name)
            continue
        # Only linked while the upstream operator has no downstream tasks yet.
        if not upstream_operator.downstream_task_ids:
            op.set_upstream(upstream_operator)

    # Dependencies implied by operator / xcom call arguments (by task id).
    for task_id in upstream_task_ids:
        upstream_task = dag.task_dict.get(task_id)
        if not upstream_task:
            self.__log_task_not_found_error(op, task_id)
            continue
        op.set_upstream(upstream_task)

    # populate the operator with the current param values
    for k, v in six.iteritems(dbnd_task_params):
        setattr(op, k, v)

    results = [(n, build_xcom_str(op=op, name=n)) for n in dbnd_xcom_outputs]

    for n, xcom_arg in results:
        # Outputs overridden through the ctor keep the value set just above.
        if n in user_ctor_outputs_only_names:
            continue
        setattr(op, n, xcom_arg)

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            task.ctrl.banner("Created task '%s'." % task.task_name, color="green")
        )
        logger.debug(
            "%s\n\tparams: %s\n\toutputs: %s",
            task.task_id,
            dbnd_task_params,
            results,
        )

    if single_value_result:
        # results is [("result", <xcom string>)] here; return the xcom string
        return results[0][1]

    return XComResults(result=build_xcom_str(op), sub_results=results)