def _get_unique_task_id(task_id: str, dag: Optional["DAG"] = None) -> str:
    """Generate unique task id given a DAG (or if run in a DAG context).

    Ids are generated by appending a unique number to the end of
    the original task id::

        task_id
        task_id__1
        task_id__2
        ...
        task_id__20

    :param task_id: the id requested for the task
    :param dag: DAG to check for collisions; defaults to the context-managed DAG
    :return: ``task_id`` unchanged when free, otherwise the base id with the
        next unused ``__<n>`` suffix appended
    """
    dag = dag or DagContext.get_current_dag()
    if not dag or task_id not in dag.task_ids:
        return task_id

    # Strip any existing "__<n>" suffix to recover the base id.
    core = re.split(r'__\d+$', task_id)[0]
    # re.escape keeps regex metacharacters in the id (e.g. dots) from
    # accidentally matching unrelated task ids; the capture group extracts
    # the numeric suffix directly instead of re-splitting each candidate.
    suffix_pattern = re.compile(rf'^{re.escape(core)}__(\d+)$')
    suffixes = [
        int(match.group(1))
        for match in (suffix_pattern.match(existing) for existing in dag.task_ids)
        if match
    ]
    if not suffixes:
        return f'{core}__1'
    return f'{core}__{max(suffixes) + 1}'
def get_unique_task_id(
    task_id: str,
    dag: Optional["DAG"] = None,
    task_group: Optional["TaskGroup"] = None,
) -> str:
    """Generate unique task id given a DAG (or if run in a DAG context).

    Ids are generated by appending a unique number to the end of
    the original task id::

        task_id
        task_id__1
        task_id__2
        ...
        task_id__20

    :param task_id: the id requested for the task
    :param dag: DAG to check for collisions; defaults to the context-managed DAG
    :param task_group: TaskGroup the task will be registered under; defaults to
        the context-managed TaskGroup
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered (the registered id carries the group prefix).
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id
    if tg_task_id not in dag.task_ids:
        return task_id

    core = re.split(r'__\d+$', task_id)[0]
    # re.escape prevents regex metacharacters in the id (e.g. dots from a
    # TaskGroup prefix) from matching unrelated task ids.
    suffix_pattern = re.compile(rf'^{re.escape(core)}__(\d+)$')
    suffixes = [
        int(match.group(1))
        for match in (suffix_pattern.match(existing) for existing in dag.task_ids)
        if match
    ]
    if not suffixes:
        return f'{core}__1'
    return f'{core}__{max(suffixes) + 1}'
def operator_constructor(**kwargs):
    """Build a PythonOperator around the closed-over ``python_callable``.

    Resolves task_id/pool/dag from explicit kwargs, the closed-over
    ``defaults`` dict, or the active DAG context, then back-fills any
    parameter of the callable's signature from the merged default_args.

    NOTE(review): the resolved task_id/pool are written back into the shared
    ``defaults`` dict, so they persist across invocations of this
    constructor — confirm this is intentional.
    """
    # Explicit kwarg wins, then a previously stored default, then the
    # callable's own name.
    defaults['task_id'] = kwargs.pop('task_id', None) or defaults.get('task_id') or python_callable.__name__
    defaults['pool'] = kwargs.pop('pool', None) or defaults.get('pool')
    try:
        # Airflow 1.x kept the context-managed DAG on settings.
        cmdag = settings.CONTEXT_MANAGER_DAG
    except AttributeError:
        # Airflow 2.0+
        from airflow.models.dag import DagContext
        cmdag = DagContext.get_current_dag()
    dag = kwargs.get('dag', None) or defaults.get('dag', None) or cmdag
    dag_args = copy(dag.default_args) if dag else {}
    dag_params = copy(dag.params) if dag else {}
    default_args = {}
    if 'default_args' in defaults:
        default_args = defaults['default_args']
        if 'params' in default_args:
            # 'params' is merged into the DAG params rather than passed
            # through as an operator argument.
            dag_params.update(default_args['params'])
            del default_args['params']
    # DAG-level defaults overlaid with decorator-level default_args.
    dag_args.update(default_args)
    default_args = dag_args
    # Back-fill callable parameters that were not supplied explicitly.
    for arg in signature.parameters:
        if arg not in kwargs and arg in default_args:
            kwargs[arg] = default_args[arg]
    return PythonOperator(
        python_callable=python_callable,
        op_kwargs=kwargs,
        params=dag_params,
        **defaults,
    )
def map(self, **kwargs: "MapArgument") -> XComArg:
    """Create a DecoratedMappedOperator expanding this task over ``kwargs``.

    Resolves DAG/TaskGroup from the partial kwargs or the active context,
    normalizes the partial arguments (retries, pool, dates, resources), and
    returns an :class:`XComArg` wrapping the mapped operator.

    :param kwargs: mapped arguments; each value expands into the callable's
        ``op_kwargs`` at run time
    """
    self._validate_arg_names("map", kwargs)
    partial_kwargs = self.kwargs.copy()
    # Resolve DAG and TaskGroup from explicit kwargs or the active context.
    dag = partial_kwargs.pop("dag", DagContext.get_current_dag())
    task_group = partial_kwargs.pop(
        "task_group", TaskGroupContext.get_current_task_group(dag))
    task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
    params = partial_kwargs.pop("params", None)

    # Logic here should be kept in sync with BaseOperatorMeta.partial().
    if "task_concurrency" in partial_kwargs:
        raise TypeError("unexpected argument: task_concurrency")
    if partial_kwargs.get("wait_for_downstream"):
        # wait_for_downstream implies depends_on_past.
        partial_kwargs["depends_on_past"] = True
    start_date = timezone.convert_to_utc(
        partial_kwargs.pop("start_date", None))
    end_date = timezone.convert_to_utc(partial_kwargs.pop(
        "end_date", None))
    if partial_kwargs.get("pool") is None:
        partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
    partial_kwargs["retries"] = parse_retries(
        partial_kwargs.get("retries", DEFAULT_RETRIES))
    partial_kwargs["retry_delay"] = coerce_retry_delay(
        partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
    )
    partial_kwargs["resources"] = coerce_resources(
        partial_kwargs.get("resources"))
    partial_kwargs.setdefault("executor_config", {})
    partial_kwargs.setdefault("op_args", [])
    partial_kwargs.setdefault("op_kwargs", {})

    # Mypy does not work well with a subclassed attrs class :(
    _MappedOperator = cast(Any, DecoratedMappedOperator)
    operator = _MappedOperator(
        operator_class=self.operator_class,
        mapped_kwargs={},
        partial_kwargs=partial_kwargs,
        task_id=task_id,
        params=params,
        deps=MappedOperator.deps_for(self.operator_class),
        operator_extra_links=self.operator_class.operator_extra_links,
        template_ext=self.operator_class.template_ext,
        template_fields=self.operator_class.template_fields,
        ui_color=self.operator_class.ui_color,
        ui_fgcolor=self.operator_class.ui_fgcolor,
        is_dummy=False,
        task_module=self.operator_class.__module__,
        task_type=self.operator_class.__name__,
        dag=dag,
        task_group=task_group,
        start_date=start_date,
        end_date=end_date,
        multiple_outputs=self.multiple_outputs,
        python_callable=self.function,
        mapped_op_kwargs=kwargs,
    )
    return XComArg(operator=operator)
def __init__(
    self,
    group_id: Optional[str],
    prefix_group_id: bool = True,
    parent_group: Optional["TaskGroup"] = None,
    dag: Optional["DAG"] = None,
    tooltip: str = "",
    ui_color: str = "CornflowerBlue",
    ui_fgcolor: str = "#000",
):
    """Create a TaskGroup; ``group_id=None`` creates the root group.

    Non-root groups resolve their parent from ``parent_group`` or the
    active TaskGroup context and register their id (plus join ids) in the
    DAG-wide ``used_group_ids`` set to prevent duplicates.
    """
    from airflow.models.dag import DagContext

    self.prefix_group_id = prefix_group_id

    if group_id is None:
        # This creates a root TaskGroup.
        if parent_group:
            raise AirflowException("Root TaskGroup cannot have parent_group")
        # used_group_ids is shared across all TaskGroups in the same DAG to keep track
        # of used group_id to avoid duplication.
        self.used_group_ids: Set[Optional[str]] = set()
        self._parent_group = None
    else:
        if not isinstance(group_id, str):
            raise ValueError("group_id must be str")
        if not group_id:
            raise ValueError("group_id must not be empty")
        dag = dag or DagContext.get_current_dag()
        if not parent_group and not dag:
            raise AirflowException("TaskGroup can only be used inside a dag")
        self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
        if not self._parent_group:
            raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
        # Share the DAG-wide registry with the parent group.
        self.used_group_ids = self._parent_group.used_group_ids

    self._group_id = group_id
    if self.group_id in self.used_group_ids:
        raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG")
    # Reserve this group's id and its synthetic join ids.
    self.used_group_ids.add(self.group_id)
    self.used_group_ids.add(self.downstream_join_id)
    self.used_group_ids.add(self.upstream_join_id)
    self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
    if self._parent_group:
        self._parent_group.add(self)

    self.tooltip = tooltip
    self.ui_color = ui_color
    self.ui_fgcolor = ui_fgcolor

    # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
    # so that we can optimize the number of edges when entire TaskGroups depend on each other.
    self.upstream_group_ids: Set[Optional[str]] = set()
    self.downstream_group_ids: Set[Optional[str]] = set()
    self.upstream_task_ids: Set[Optional[str]] = set()
    self.downstream_task_ids: Set[Optional[str]] = set()
def get_current_task_group(cls, dag: Optional["DAG"]) -> Optional[TaskGroup]:
    """Return the TaskGroup currently in scope.

    Prefers an explicitly context-managed TaskGroup; otherwise falls back
    to the root TaskGroup of ``dag`` (or of the current context-managed
    DAG), and finally to the (empty) managed value itself.
    """
    from airflow.models.dag import DagContext

    current = cls._context_managed_task_group
    if current:
        return current
    target_dag = dag or DagContext.get_current_dag()
    if target_dag:
        # A DAG is active but no TaskGroup is: use the DAG's root group.
        return target_dag.task_group
    return current
def _validate_dag(self, kwargs):
    """Ensure the subdag's dag_id is '<parent_dag_id>.<this_task_id>'."""
    parent_dag = kwargs.get('dag') or DagContext.get_current_dag()

    if not parent_dag:
        raise AirflowException('Please pass in the `dag` param or call within a DAG context manager')

    expected_dag_id = parent_dag.dag_id + '.' + kwargs['task_id']
    if expected_dag_id != self.subdag.dag_id:
        raise AirflowException(
            "The subdag's dag_id should have the form '{{parent_dag_id}}.{{this_task_id}}'. "
            "Expected '{d}.{t}'; received '{rcvd}'.".format(
                d=parent_dag.dag_id, t=kwargs['task_id'], rcvd=self.subdag.dag_id)
        )
def map(
    self,
    *,
    dag: Optional["DAG"] = None,
    task_group: Optional["TaskGroup"] = None,
    **kwargs
) -> XComArg:
    """Build a MappedOperator for this decorator and wrap it in an XComArg.

    The DAG and TaskGroup default to the currently active contexts, and the
    task id is de-duplicated against the resolved DAG.
    """
    self._validate_arg_names("map", kwargs)
    resolved_dag = dag or DagContext.get_current_dag()
    resolved_group = task_group or TaskGroupContext.get_current_task_group(resolved_dag)
    unique_id = get_unique_task_id(self.kwargs['task_id'], resolved_dag, resolved_group)
    mapped = MappedOperator.from_decorator(
        decorator=self,
        dag=resolved_dag,
        task_group=resolved_group,
        task_id=unique_id,
        mapped_kwargs=kwargs,
    )
    return XComArg(operator=mapped)
def safe_get_context_manager_dag():
    """
    Try to find the CONTEXT_MANAGER_DAG object inside airflow.
    It was moved between versions, so we look for it in all the hiding places
    that we know of.
    """
    if AIRFLOW_VERSION_2:
        # Airflow 2.x exposes the context-managed DAG via DagContext.
        from airflow.models.dag import DagContext
        return DagContext.get_current_dag()

    # Airflow 1.x: probe each historical location in order.
    candidates = (
        (settings, "CONTEXT_MANAGER_DAG"),
        (DAG, "_CONTEXT_MANAGER_DAG"),
        (models, "_CONTEXT_MANAGER_DAG"),
    )
    for holder, attr_name in candidates:
        if hasattr(holder, attr_name):
            return getattr(holder, attr_name)
    return None
def get_unique_task_id(
    task_id: str,
    dag: Optional["DAG"] = None,
    task_group: Optional["TaskGroup"] = None,
) -> str:
    """Generate unique task id given a DAG (or if run in a DAG context).

    Ids are generated by appending a unique number to the end of
    the original task id::

        task_id
        task_id__1
        task_id__2
        ...
        task_id__20

    :param task_id: the id requested for the task
    :param dag: DAG to check for collisions; defaults to the context-managed DAG
    :param task_group: TaskGroup the task will be registered under; defaults to
        the context-managed TaskGroup
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id
    if tg_task_id not in dag.task_ids:
        return task_id

    def _find_id_suffixes(dag: "DAG") -> Iterator[int]:
        prefix = re.split(r"__\d+$", tg_task_id)[0]
        # re.escape keeps regex metacharacters in the prefix (e.g. the dot
        # separating a TaskGroup prefix) from matching unrelated task ids.
        suffix_pattern = re.compile(rf"^{re.escape(prefix)}__(\d+)$")
        for task_id in dag.task_ids:
            match = suffix_pattern.match(task_id)
            if match is None:
                continue
            yield int(match.group(1))
        yield 0  # Default if there's no matching task ID.

    core = re.split(r"__\d+$", task_id)[0]
    return f"{core}__{max(_find_id_suffixes(dag)) + 1}"
def wrapper(*args: Any, **kwargs: Any) -> Any:
    """Merge DAG-level default_args/params into the operator kwargs.

    Back-fills any missing constructor arguments of the wrapped ``func``
    from the merged defaults and raises if required (non-optional)
    arguments are still missing.
    """
    from airflow.models.dag import DagContext
    if len(args) > 1:
        # Positional args beyond self are disallowed for operators.
        raise AirflowException(
            "Use keyword arguments when initializing operators")
    dag_args: Dict[str, Any] = {}
    dag_params: Dict[str, Any] = {}
    dag = kwargs.get('dag') or DagContext.get_current_dag()
    if dag:
        dag_args = copy(dag.default_args) or {}
        dag_params = copy(dag.params) or {}

    params = kwargs.get('params', {}) or {}
    dag_params.update(params)

    default_args = {}
    if 'default_args' in kwargs:
        default_args = kwargs['default_args']
        if 'params' in default_args:
            # 'params' inside default_args is merged into the DAG params
            # rather than passed as a constructor argument.
            dag_params.update(default_args['params'])
            del default_args['params']

    # DAG-level defaults overlaid with operator-level default_args.
    dag_args.update(default_args)
    default_args = dag_args

    # Back-fill constructor parameters not supplied explicitly.
    for arg in sig_cache.parameters:
        if arg not in kwargs and arg in default_args:
            kwargs[arg] = default_args[arg]

    missing_args = list(non_optional_args - set(kwargs))
    if missing_args:
        msg = f"Argument {missing_args} is required"
        raise AirflowException(msg)

    kwargs['params'] = dag_params

    result = func(*args, **kwargs)
    return result
def map(self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs) -> XComArg:
    """Build a MappedOperator for this decorator and wrap it in an XComArg.

    DAG and TaskGroup fall back to the active contexts; the task id is
    de-duplicated against the resolved DAG before construction.
    """
    resolved_dag = dag or DagContext.get_current_dag()
    resolved_group = task_group or TaskGroupContext.get_current_task_group(resolved_dag)
    unique_id = get_unique_task_id(self.kwargs['task_id'], resolved_dag, resolved_group)
    self._validate_arg_names("map", kwargs)
    # Set mapped_kwargs to empty to bypass the validation, as for decorated stuff we validate ourselves
    mapped = MappedOperator(
        operator_class=self.operator_class,
        task_id=unique_id,
        dag=resolved_dag,
        task_group=resolved_group,
        partial_kwargs=self.kwargs,
        mapped_kwargs={},
    )
    mapped.mapped_kwargs.update(kwargs)
    return XComArg(operator=mapped)
def __init__(
    self,
    task_id: str,
    owner: str = conf.get('operators', 'DEFAULT_OWNER'),
    email: Optional[Union[str, Iterable[str]]] = None,
    email_on_retry: bool = True,
    email_on_failure: bool = True,
    retries: Optional[int] = conf.getint('core', 'default_task_retries', fallback=0),
    retry_delay: timedelta = timedelta(seconds=300),
    retry_exponential_backoff: bool = False,
    max_retry_delay: Optional[datetime] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    depends_on_past: bool = False,
    wait_for_downstream: bool = False,
    dag=None,
    params: Optional[Dict] = None,
    default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
    priority_weight: int = 1,
    weight_rule: str = WeightRule.DOWNSTREAM,
    queue: str = conf.get('celery', 'default_queue'),
    pool: str = Pool.DEFAULT_POOL_NAME,
    sla: Optional[timedelta] = None,
    execution_timeout: Optional[timedelta] = None,
    on_execute_callback: Optional[Callable] = None,
    on_failure_callback: Optional[Callable] = None,
    on_success_callback: Optional[Callable] = None,
    on_retry_callback: Optional[Callable] = None,
    trigger_rule: str = TriggerRule.ALL_SUCCESS,
    resources: Optional[Dict] = None,
    run_as_user: Optional[str] = None,
    task_concurrency: Optional[int] = None,
    executor_config: Optional[Dict] = None,
    do_xcom_push: bool = True,
    inlets: Optional[Any] = None,
    outlets: Optional[Any] = None,
    *args,
    **kwargs
):
    """Initialize the operator.

    Validates task_id/trigger_rule/weight_rule, normalizes dates to UTC
    and retry_delay to a timedelta, binds the DAG from the argument or the
    active DAG context, and sets up lineage bookkeeping.
    """
    from airflow.models.dag import DagContext

    super().__init__()
    if args or kwargs:
        # Unknown constructor arguments: hard error unless explicitly allowed
        # by configuration, in which case only a deprecation warning is issued.
        if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
            raise AirflowException(
                "Invalid arguments were passed to {c} (task_id: {t}). Invalid "
                "arguments were:\n*args: {a}\n**kwargs: {k}".format(
                    c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
            )
        warnings.warn(
            'Invalid arguments were passed to {c} (task_id: {t}). '
            'Support for passing such arguments will be dropped in '
            'future. Invalid arguments were:'
            '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
            category=PendingDeprecationWarning,
            stacklevel=3)

    validate_key(task_id)
    self.task_id = task_id
    self.owner = owner
    self.email = email
    self.email_on_retry = email_on_retry
    self.email_on_failure = email_on_failure

    self.start_date = start_date
    if start_date and not isinstance(start_date, datetime):
        # Non-datetime start_date is kept as-is; only a warning is logged.
        self.log.warning("start_date for %s isn't datetime.datetime", self)
    elif start_date:
        self.start_date = timezone.convert_to_utc(start_date)

    self.end_date = end_date
    if end_date:
        self.end_date = timezone.convert_to_utc(end_date)

    if not TriggerRule.is_valid(trigger_rule):
        raise AirflowException(
            "The trigger_rule must be one of {all_triggers},"
            "'{d}.{t}'; received '{tr}'.".format(
                all_triggers=TriggerRule.all_triggers(),
                d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

    self.trigger_rule = trigger_rule
    self.depends_on_past = depends_on_past
    self.wait_for_downstream = wait_for_downstream
    if wait_for_downstream:
        # Waiting for the downstream of the previous run implies depending
        # on the previous run.
        self.depends_on_past = True

    self.retries = retries
    self.queue = queue
    self.pool = pool
    self.sla = sla
    self.execution_timeout = execution_timeout
    self.on_execute_callback = on_execute_callback
    self.on_failure_callback = on_failure_callback
    self.on_success_callback = on_success_callback
    self.on_retry_callback = on_retry_callback

    if isinstance(retry_delay, timedelta):
        self.retry_delay = retry_delay
    else:
        self.log.debug("Retry_delay isn't timedelta object, assuming secs")
        # noinspection PyTypeChecker
        self.retry_delay = timedelta(seconds=retry_delay)

    self.retry_exponential_backoff = retry_exponential_backoff
    self.max_retry_delay = max_retry_delay
    self.params = params or {}  # Available in templates!
    self.priority_weight = priority_weight
    if not WeightRule.is_valid(weight_rule):
        raise AirflowException(
            "The weight_rule must be one of {all_weight_rules},"
            "'{d}.{t}'; received '{tr}'.".format(
                all_weight_rules=WeightRule.all_weight_rules,
                d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
    self.weight_rule = weight_rule
    self.resources: Optional[Resources] = Resources(**resources) if resources else None
    self.run_as_user = run_as_user
    self.task_concurrency = task_concurrency
    self.executor_config = executor_config or {}
    self.do_xcom_push = do_xcom_push

    # Private attributes
    self._upstream_task_ids: Set[str] = set()
    self._downstream_task_ids: Set[str] = set()
    self._dag = None

    # Bind to the explicit DAG or the context-managed one, if any.
    self.dag = dag or DagContext.get_current_dag()

    # subdag parameter is only set for SubDagOperator.
    # Setting it to None by default as other Operators do not have that field
    from airflow.models.dag import DAG
    self.subdag: Optional[DAG] = None

    self._log = logging.getLogger("airflow.task.operators")

    # Lineage
    self.inlets: List = []
    self.outlets: List = []
    self._inlets: List = []
    self._outlets: List = []

    if inlets:
        self._inlets = inlets if isinstance(inlets, list) else [
            inlets,
        ]

    if outlets:
        self._outlets = outlets if isinstance(outlets, list) else [
            outlets,
        ]
def __init__(
    self,
    group_id: Optional[str],
    prefix_group_id: bool = True,
    parent_group: Optional["TaskGroup"] = None,
    dag: Optional["DAG"] = None,
    default_args: Optional[Dict] = None,
    tooltip: str = "",
    ui_color: str = "CornflowerBlue",
    ui_fgcolor: str = "#000",
    add_suffix_on_collision: bool = False,
):
    """Create a TaskGroup; ``group_id=None`` creates the root group.

    Non-root groups resolve their parent from ``parent_group`` or the
    active TaskGroup context, must belong to the same DAG as that parent,
    and register their id (plus join ids) in the DAG-wide
    ``used_group_ids`` set.
    """
    from airflow.models.dag import DagContext

    self.prefix_group_id = prefix_group_id
    self.default_args = copy.deepcopy(default_args or {})
    dag = dag or DagContext.get_current_dag()

    if group_id is None:
        # This creates a root TaskGroup.
        if parent_group:
            raise AirflowException("Root TaskGroup cannot have parent_group")
        # used_group_ids is shared across all TaskGroups in the same DAG to keep track
        # of used group_id to avoid duplication.
        self.used_group_ids = set()
        self.dag = dag
    else:
        if prefix_group_id:
            # If group id is used as prefix, it should not contain spaces nor dots
            # because it is used as prefix in the task_id
            validate_group_key(group_id)
        else:
            if not isinstance(group_id, str):
                raise ValueError("group_id must be str")
            if not group_id:
                raise ValueError("group_id must not be empty")

        if not parent_group and not dag:
            raise AirflowException("TaskGroup can only be used inside a dag")

        parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
        if not parent_group:
            raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
        if dag is not parent_group.dag:
            raise RuntimeError(
                "Cannot mix TaskGroups from different DAGs: %s and %s", dag, parent_group.dag)

        # Share the DAG-wide registry with the parent group.
        self.used_group_ids = parent_group.used_group_ids

    # if given group_id already used assign suffix by incrementing largest used suffix integer
    # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
    self._group_id = group_id
    self._check_for_group_id_collisions(add_suffix_on_collision)

    self.children: Dict[str, DAGNode] = {}
    if parent_group:
        parent_group.add(self)

    # Reserve this group's id and, for non-root groups, its join ids.
    self.used_group_ids.add(self.group_id)
    if self.group_id:
        self.used_group_ids.add(self.downstream_join_id)
        self.used_group_ids.add(self.upstream_join_id)

    self.tooltip = tooltip
    self.ui_color = ui_color
    self.ui_fgcolor = ui_fgcolor

    # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
    # so that we can optimize the number of edges when entire TaskGroups depend on each other.
    self.upstream_group_ids: Set[Optional[str]] = set()
    self.downstream_group_ids: Set[Optional[str]] = set()
    self.upstream_task_ids = set()
    self.downstream_task_ids = set()
def __init__(
    self,
    group_id: Optional[str],
    prefix_group_id: bool = True,
    parent_group: Optional["TaskGroup"] = None,
    dag: Optional["DAG"] = None,
    default_args: Optional[Dict] = None,
    tooltip: str = "",
    ui_color: str = "CornflowerBlue",
    ui_fgcolor: str = "#000",
    add_suffix_on_collision: bool = False,
):
    """Create a TaskGroup; ``group_id=None`` creates the root group.

    On a group_id collision, either raises DuplicateTaskIdFound or (when
    ``add_suffix_on_collision`` is set) appends the next free ``__<n>``
    suffix. Registered ids live in the DAG-wide ``used_group_ids`` set.
    """
    from airflow.models.dag import DagContext

    self.prefix_group_id = prefix_group_id
    self.default_args = copy.deepcopy(default_args or {})
    if group_id is None:
        # This creates a root TaskGroup.
        if parent_group:
            raise AirflowException("Root TaskGroup cannot have parent_group")
        # used_group_ids is shared across all TaskGroups in the same DAG to keep track
        # of used group_id to avoid duplication.
        self.used_group_ids: Set[Optional[str]] = set()
        self._parent_group = None
    else:
        if not isinstance(group_id, str):
            raise ValueError("group_id must be str")
        if not group_id:
            raise ValueError("group_id must not be empty")
        dag = dag or DagContext.get_current_dag()
        if not parent_group and not dag:
            raise AirflowException("TaskGroup can only be used inside a dag")
        self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
        if not self._parent_group:
            raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
        # Share the DAG-wide registry with the parent group.
        self.used_group_ids = self._parent_group.used_group_ids

    self._group_id = group_id
    # if given group_id already used assign suffix by incrementing largest used suffix integer
    # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
    if group_id in self.used_group_ids:
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG")
        base = re.split(r'__\d+$', group_id)[0]
        # NOTE(review): `base` is interpolated into the pattern unescaped, so
        # regex metacharacters in a group_id would match too broadly — confirm
        # whether re.escape should be applied here.
        suffixes = sorted(
            int(re.split(r'^.+__', used_group_id)[1])
            for used_group_id in self.used_group_ids
            if used_group_id is not None and re.match(rf'^{base}__\d+$', used_group_id))
        if not suffixes:
            self._group_id += '__1'
        else:
            self._group_id = f'{base}__{suffixes[-1] + 1}'

    # Reserve this group's id and its synthetic join ids.
    self.used_group_ids.add(self.group_id)
    self.used_group_ids.add(self.downstream_join_id)
    self.used_group_ids.add(self.upstream_join_id)
    self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
    if self._parent_group:
        self._parent_group.add(self)

    self.tooltip = tooltip
    self.ui_color = ui_color
    self.ui_fgcolor = ui_fgcolor

    # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
    # so that we can optimize the number of edges when entire TaskGroups depend on each other.
    self.upstream_group_ids: Set[Optional[str]] = set()
    self.downstream_group_ids: Set[Optional[str]] = set()
    self.upstream_task_ids: Set[Optional[str]] = set()
    self.downstream_task_ids: Set[Optional[str]] = set()
def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
    """Create a DecoratedMappedOperator expanding over ``expand_input``.

    Resolves DAG/TaskGroup from the stored kwargs or the active context,
    merges DAG/TaskGroup defaults, normalizes the partial arguments, and
    returns an :class:`XComArg` wrapping the mapped operator.

    :param expand_input: the values to map over (expand into op_kwargs)
    :param strict: when True, forbid mapped kwargs from overriding partials
    """
    ensure_xcomarg_return_value(expand_input.value)

    task_kwargs = self.kwargs.copy()
    dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
    task_group = task_kwargs.pop(
        "task_group", None) or TaskGroupContext.get_current_task_group(dag)

    # Merge DAG- and TaskGroup-level defaults under the task's own values.
    partial_kwargs, default_params = get_merged_defaults(
        dag=dag,
        task_group=task_group,
        task_params=task_kwargs.pop("params", None),
        task_default_args=task_kwargs.pop("default_args", None),
    )
    partial_kwargs.update(task_kwargs)

    task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
    params = partial_kwargs.pop("params", None) or default_params

    # Logic here should be kept in sync with BaseOperatorMeta.partial().
    if "task_concurrency" in partial_kwargs:
        raise TypeError("unexpected argument: task_concurrency")
    if partial_kwargs.get("wait_for_downstream"):
        # wait_for_downstream implies depends_on_past.
        partial_kwargs["depends_on_past"] = True
    start_date = timezone.convert_to_utc(
        partial_kwargs.pop("start_date", None))
    end_date = timezone.convert_to_utc(partial_kwargs.pop(
        "end_date", None))
    if partial_kwargs.get("pool") is None:
        partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
    partial_kwargs["retries"] = parse_retries(
        partial_kwargs.get("retries", DEFAULT_RETRIES))
    partial_kwargs["retry_delay"] = coerce_timedelta(
        partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
        key="retry_delay",
    )
    max_retry_delay = partial_kwargs.get("max_retry_delay")
    partial_kwargs["max_retry_delay"] = (
        max_retry_delay if max_retry_delay is None else coerce_timedelta(
            max_retry_delay, key="max_retry_delay"))
    partial_kwargs["resources"] = coerce_resources(
        partial_kwargs.get("resources"))
    partial_kwargs.setdefault("executor_config", {})
    partial_kwargs.setdefault("op_args", [])
    partial_kwargs.setdefault("op_kwargs", {})

    # Mypy does not work well with a subclassed attrs class :(
    _MappedOperator = cast(Any, DecoratedMappedOperator)

    try:
        operator_name = self.operator_class.custom_operator_name  # type: ignore
    except AttributeError:
        operator_name = self.operator_class.__name__

    operator = _MappedOperator(
        operator_class=self.operator_class,
        expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
        partial_kwargs=partial_kwargs,
        task_id=task_id,
        params=params,
        deps=MappedOperator.deps_for(self.operator_class),
        operator_extra_links=self.operator_class.operator_extra_links,
        template_ext=self.operator_class.template_ext,
        template_fields=self.operator_class.template_fields,
        template_fields_renderers=self.operator_class.template_fields_renderers,
        ui_color=self.operator_class.ui_color,
        ui_fgcolor=self.operator_class.ui_fgcolor,
        is_empty=False,
        task_module=self.operator_class.__module__,
        task_type=self.operator_class.__name__,
        operator_name=operator_name,
        dag=dag,
        task_group=task_group,
        start_date=start_date,
        end_date=end_date,
        multiple_outputs=self.multiple_outputs,
        python_callable=self.function,
        op_kwargs_expand_input=expand_input,
        disallow_kwargs_override=strict,
        # Different from classic operators, kwargs passed to a taskflow
        # task's expand() contribute to the op_kwargs operator argument, not
        # the operator arguments themselves, and should expand against it.
        expand_input_attr="op_kwargs_expand_input",
    )
    return XComArg(operator=operator)
def expand(self, **map_kwargs: "Mappable") -> XComArg:
    """Create a DecoratedMappedOperator expanding this task over ``map_kwargs``.

    Validates the mapped argument names, forbids overlap with already
    partially-bound kwargs, merges DAG/TaskGroup defaults, normalizes the
    partial arguments, and returns an :class:`XComArg` wrapping the mapped
    operator.
    """
    self._validate_arg_names("expand", map_kwargs)
    # Mapped kwargs must not collide with ones already bound via partial().
    prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
    ensure_xcomarg_return_value(map_kwargs)

    task_kwargs = self.kwargs.copy()
    dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
    task_group = task_kwargs.pop(
        "task_group", None) or TaskGroupContext.get_current_task_group(dag)

    # Merge DAG- and TaskGroup-level defaults under the task's own values.
    partial_kwargs, default_params = get_merged_defaults(
        dag=dag,
        task_group=task_group,
        task_params=task_kwargs.pop("params", None),
        task_default_args=task_kwargs.pop("default_args", None),
    )
    partial_kwargs.update(task_kwargs)

    task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
    params = partial_kwargs.pop("params", None) or default_params

    # Logic here should be kept in sync with BaseOperatorMeta.partial().
    if "task_concurrency" in partial_kwargs:
        raise TypeError("unexpected argument: task_concurrency")
    if partial_kwargs.get("wait_for_downstream"):
        # wait_for_downstream implies depends_on_past.
        partial_kwargs["depends_on_past"] = True
    start_date = timezone.convert_to_utc(
        partial_kwargs.pop("start_date", None))
    end_date = timezone.convert_to_utc(partial_kwargs.pop(
        "end_date", None))
    if partial_kwargs.get("pool") is None:
        partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
    partial_kwargs["retries"] = parse_retries(
        partial_kwargs.get("retries", DEFAULT_RETRIES))
    partial_kwargs["retry_delay"] = coerce_retry_delay(
        partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
    )
    partial_kwargs["resources"] = coerce_resources(
        partial_kwargs.get("resources"))
    partial_kwargs.setdefault("executor_config", {})
    partial_kwargs.setdefault("op_args", [])
    partial_kwargs.setdefault("op_kwargs", {})

    # Mypy does not work well with a subclassed attrs class :(
    _MappedOperator = cast(Any, DecoratedMappedOperator)
    operator = _MappedOperator(
        operator_class=self.operator_class,
        mapped_kwargs={},
        partial_kwargs=partial_kwargs,
        task_id=task_id,
        params=params,
        deps=MappedOperator.deps_for(self.operator_class),
        operator_extra_links=self.operator_class.operator_extra_links,
        template_ext=self.operator_class.template_ext,
        template_fields=self.operator_class.template_fields,
        template_fields_renderers=self.operator_class.template_fields_renderers,
        ui_color=self.operator_class.ui_color,
        ui_fgcolor=self.operator_class.ui_fgcolor,
        is_empty=False,
        task_module=self.operator_class.__module__,
        task_type=self.operator_class.__name__,
        dag=dag,
        task_group=task_group,
        start_date=start_date,
        end_date=end_date,
        multiple_outputs=self.multiple_outputs,
        python_callable=self.function,
        mapped_op_kwargs=map_kwargs,
        # Different from classic operators, kwargs passed to a taskflow
        # task's expand() contribute to the op_kwargs operator argument, not
        # the operator arguments themselves, and should expand against it.
        expansion_kwargs_attr="mapped_op_kwargs",
    )
    return XComArg(operator=operator)