Example #1
    def _get_unique_task_id(task_id: str, dag: Optional[DAG] = None) -> str:
        """
        Generate a unique task id given a DAG (or if run in a DAG context).
        IDs are generated by appending a unique number to the end of
        the original task id.

        Example:
          task_id
          task_id__1
          task_id__2
          ...
          task_id__20
        """
        dag = dag or DagContext.get_current_dag()
        if not dag or task_id not in dag.task_ids:
            return task_id
        core = re.split(r'__\d+$', task_id)[0]
        suffixes = sorted(
            [
                int(re.split(r'^.+__', task_id)[1])
                for task_id in dag.task_ids
                if re.match(rf'^{core}__\d+$', task_id)
            ]
        )
        if not suffixes:
            return f'{core}__1'
        return f'{core}__{suffixes[-1] + 1}'
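The suffixing scheme is easy to exercise outside Airflow. Below is a minimal standalone sketch of the same logic, where the hypothetical existing_ids set stands in for dag.task_ids:

import re

def unique_id(task_id: str, existing_ids: set) -> str:
    # Stand-in for _get_unique_task_id above; existing_ids plays the role of dag.task_ids.
    if task_id not in existing_ids:
        return task_id
    core = re.split(r'__\d+$', task_id)[0]
    suffixes = sorted(
        int(re.split(r'^.+__', tid)[1])
        for tid in existing_ids
        if re.match(rf'^{core}__\d+$', tid)
    )
    if not suffixes:
        return f'{core}__1'
    return f'{core}__{suffixes[-1] + 1}'

print(unique_id('task_id', {'task_id'}))                # -> task_id__1
print(unique_id('task_id', {'task_id', 'task_id__1'}))  # -> task_id__2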
Example #2
def get_unique_task_id(
    task_id: str, dag: Optional[DAG] = None, task_group: Optional[TaskGroup] = None
) -> str:
    """
    Generate a unique task id given a DAG (or if run in a DAG context).
    IDs are generated by appending a unique number to the end of
    the original task id.

    Example:
      task_id
      task_id__1
      task_id__2
      ...
      task_id__20
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id

    if tg_task_id not in dag.task_ids:
        return task_id
    core = re.split(r'__\d+$', task_id)[0]
    suffixes = sorted(
        int(re.split(r'^.+__', task_id)[1])
        for task_id in dag.task_ids
        if re.match(rf'^{core}__\d+$', task_id)
    )
    if not suffixes:
        return f'{core}__1'
    return f'{core}__{suffixes[-1] + 1}'
Example #3
        def operator_constructor(**kwargs):
            defaults['task_id'] = kwargs.pop('task_id', None) or defaults.get('task_id') or python_callable.__name__
            defaults['pool'] = kwargs.pop('pool', None) or defaults.get('pool')

            try:
                cmdag = settings.CONTEXT_MANAGER_DAG
            except AttributeError:  # Airflow 2.0+
                from airflow.models.dag import DagContext
                cmdag = DagContext.get_current_dag()

            dag = kwargs.get('dag', None) or defaults.get('dag', None) or cmdag
            dag_args = copy(dag.default_args) if dag else {}
            dag_params = copy(dag.params) if dag else {}
            default_args = {}

            if 'default_args' in defaults:
                default_args = defaults['default_args']
                if 'params' in default_args:
                    dag_params.update(default_args['params'])
                    del default_args['params']

            dag_args.update(default_args)
            default_args = dag_args

            for arg in signature.parameters:
                if arg not in kwargs and arg in default_args:
                    kwargs[arg] = default_args[arg]

            return PythonOperator(
                python_callable=python_callable,
                op_kwargs=kwargs,
                params=dag_params,
                **defaults,
            )
Example #4
    def map(self, **kwargs: "MapArgument") -> XComArg:
        self._validate_arg_names("map", kwargs)

        partial_kwargs = self.kwargs.copy()

        dag = partial_kwargs.pop("dag", DagContext.get_current_dag())
        task_group = partial_kwargs.pop(
            "task_group", TaskGroupContext.get_current_task_group(dag))
        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag,
                                     task_group)
        params = partial_kwargs.pop("params", None)

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(
            partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop(
            "end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(
            partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_retry_delay(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), )
        partial_kwargs["resources"] = coerce_resources(
            partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)
        operator = _MappedOperator(
            operator_class=self.operator_class,
            mapped_kwargs={},
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_dummy=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            mapped_op_kwargs=kwargs,
        )
        return XComArg(operator=operator)
Example #5
    def __init__(
        self,
        group_id: Optional[str],
        prefix_group_id: bool = True,
        parent_group: Optional["TaskGroup"] = None,
        dag: Optional["DAG"] = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException("Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids: Set[Optional[str]] = set()
            self._parent_group = None
        else:
            if not isinstance(group_id, str):
                raise ValueError("group_id must be str")
            if not group_id:
                raise ValueError("group_id must not be empty")

            dag = dag or DagContext.get_current_dag()

            if not parent_group and not dag:
                raise AirflowException("TaskGroup can only be used inside a dag")

            self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
            if not self._parent_group:
                raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
            self.used_group_ids = self._parent_group.used_group_ids

        self._group_id = group_id
        if self.group_id in self.used_group_ids:
            raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG")
        self.used_group_ids.add(self.group_id)
        self.used_group_ids.add(self.downstream_join_id)
        self.used_group_ids.add(self.upstream_join_id)
        self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
        if self._parent_group:
            self._parent_group.add(self)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: Set[Optional[str]] = set()
        self.downstream_group_ids: Set[Optional[str]] = set()
        self.upstream_task_ids: Set[Optional[str]] = set()
        self.downstream_task_ids: Set[Optional[str]] = set()
Example #6
    def get_current_task_group(cls, dag: Optional["DAG"]) -> Optional[TaskGroup]:
        """Get the current TaskGroup."""
        from airflow.models.dag import DagContext

        if not cls._context_managed_task_group:
            dag = dag or DagContext.get_current_dag()
            if dag:
                # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
                return dag.task_group

        return cls._context_managed_task_group
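A minimal usage sketch of this fallback, assuming an Airflow 2.x install where TaskGroup and TaskGroupContext live in airflow.utils.task_group (the DAG id and date are illustrative):

import datetime
from airflow import DAG
from airflow.utils.task_group import TaskGroup, TaskGroupContext

with DAG('demo', start_date=datetime.datetime(2021, 1, 1)) as dag:
    # No TaskGroup context manager is active, so the DAG's root group is returned.
    assert TaskGroupContext.get_current_task_group(dag) is dag.task_group
    with TaskGroup('section') as tg:
        # Inside the context manager, the managed group takes precedence.
        assert TaskGroupContext.get_current_task_group(dag) is tg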
Example #7
    def _validate_dag(self, kwargs):
        dag = kwargs.get('dag') or DagContext.get_current_dag()

        if not dag:
            raise AirflowException('Please pass in the `dag` param or call within a DAG context manager')

        if dag.dag_id + '.' + kwargs['task_id'] != self.subdag.dag_id:
            raise AirflowException(
                "The subdag's dag_id should have the form '{{parent_dag_id}}.{{this_task_id}}'. "
                "Expected '{d}.{t}'; received '{rcvd}'.".format(
                    d=dag.dag_id, t=kwargs['task_id'], rcvd=self.subdag.dag_id)
            )
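The check encodes the SubDAG naming convention: the child DAG's dag_id must be the parent's dag_id and the operator's task_id joined by a dot. Reduced to plain strings (the ids are illustrative):

parent_dag_id = 'parent'
task_id = 'section'
subdag_id = f'{parent_dag_id}.{task_id}'  # the only id _validate_dag accepts
print(subdag_id)  # -> parent.section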
Example #8
    def map(
        self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
    ) -> XComArg:
        self._validate_arg_names("map", kwargs)
        dag = dag or DagContext.get_current_dag()
        task_group = task_group or TaskGroupContext.get_current_task_group(dag)
        task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)

        operator = MappedOperator.from_decorator(
            decorator=self,
            dag=dag,
            task_group=task_group,
            task_id=task_id,
            mapped_kwargs=kwargs,
        )
        return XComArg(operator=operator)
Example #9
def safe_get_context_manager_dag():
    """
    Try to find the CONTEXT_MANAGER_DAG object inside airflow.
    It was moved between versions, so we look for it in all the hiding places that we know of.
    """
    if AIRFLOW_VERSION_2:
        from airflow.models.dag import DagContext

        return DagContext.get_current_dag()

    if hasattr(settings, "CONTEXT_MANAGER_DAG"):
        return settings.CONTEXT_MANAGER_DAG
    elif hasattr(DAG, "_CONTEXT_MANAGER_DAG"):
        return DAG._CONTEXT_MANAGER_DAG
    elif hasattr(models, "_CONTEXT_MANAGER_DAG"):
        return models._CONTEXT_MANAGER_DAG

    return None
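A hedged usage sketch: callers get a single accessor that behaves the same across Airflow versions (the print statements are illustrative):

dag = safe_get_context_manager_dag()
if dag is not None:
    print(f'Inside a DAG context: {dag.dag_id}')
else:
    print('Not inside a `with DAG(...)` block')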
Example #10
def get_unique_task_id(
    task_id: str,
    dag: Optional[DAG] = None,
    task_group: Optional[TaskGroup] = None,
) -> str:
    """
    Generate a unique task id given a DAG (or if run in a DAG context).
    IDs are generated by appending a unique number to the end of
    the original task id.

    Example:
      task_id
      task_id__1
      task_id__2
      ...
      task_id__20
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id

    if tg_task_id not in dag.task_ids:
        return task_id

    def _find_id_suffixes(dag: DAG) -> Iterator[int]:
        prefix = re.split(r"__\d+$", tg_task_id)[0]
        for task_id in dag.task_ids:
            match = re.match(rf"^{prefix}__(\d+)$", task_id)
            if match is None:
                continue
            yield int(match.group(1))
        yield 0  # Default if there's no matching task ID.

    core = re.split(r"__\d+$", task_id)[0]
    return f"{core}__{max(_find_id_suffixes(dag)) + 1}"
Example #11
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        from airflow.models.dag import DagContext

        if len(args) > 1:
            raise AirflowException(
                "Use keyword arguments when initializing operators")
        dag_args: Dict[str, Any] = {}
        dag_params: Dict[str, Any] = {}

        dag = kwargs.get('dag') or DagContext.get_current_dag()
        if dag:
            dag_args = copy(dag.default_args) or {}
            dag_params = copy(dag.params) or {}

        params = kwargs.get('params', {}) or {}
        dag_params.update(params)

        default_args = {}
        if 'default_args' in kwargs:
            default_args = kwargs['default_args']
            if 'params' in default_args:
                dag_params.update(default_args['params'])
                del default_args['params']

        dag_args.update(default_args)
        default_args = dag_args

        for arg in sig_cache.parameters:
            if arg not in kwargs and arg in default_args:
                kwargs[arg] = default_args[arg]

        missing_args = list(non_optional_args - set(kwargs))
        if missing_args:
            msg = f"Argument {missing_args} is required"
            raise AirflowException(msg)

        kwargs['params'] = dag_params

        result = func(*args, **kwargs)
        return result
Example #12
    def map(self,
            *,
            dag: Optional["DAG"] = None,
            task_group: Optional["TaskGroup"] = None,
            **kwargs) -> XComArg:

        dag = dag or DagContext.get_current_dag()
        task_group = task_group or TaskGroupContext.get_current_task_group(dag)
        task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)

        self._validate_arg_names("map", kwargs)

        operator = MappedOperator(
            operator_class=self.operator_class,
            task_id=task_id,
            dag=dag,
            task_group=task_group,
            partial_kwargs=self.kwargs,
            # Set them to empty to bypass the validation, as for decorated stuff we validate ourselves
            mapped_kwargs={},
        )
        operator.mapped_kwargs.update(kwargs)

        return XComArg(operator=operator)
Example #13
    def __init__(
            self,
            task_id: str,
            owner: str = conf.get('operators', 'DEFAULT_OWNER'),
            email: Optional[Union[str, Iterable[str]]] = None,
            email_on_retry: bool = True,
            email_on_failure: bool = True,
            retries: Optional[int] = conf.getint('core',
                                                 'default_task_retries',
                                                 fallback=0),
            retry_delay: timedelta = timedelta(seconds=300),
            retry_exponential_backoff: bool = False,
            max_retry_delay: Optional[timedelta] = None,
            start_date: Optional[datetime] = None,
            end_date: Optional[datetime] = None,
            depends_on_past: bool = False,
            wait_for_downstream: bool = False,
            dag=None,
            params: Optional[Dict] = None,
            default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
            priority_weight: int = 1,
            weight_rule: str = WeightRule.DOWNSTREAM,
            queue: str = conf.get('celery', 'default_queue'),
            pool: str = Pool.DEFAULT_POOL_NAME,
            sla: Optional[timedelta] = None,
            execution_timeout: Optional[timedelta] = None,
            on_execute_callback: Optional[Callable] = None,
            on_failure_callback: Optional[Callable] = None,
            on_success_callback: Optional[Callable] = None,
            on_retry_callback: Optional[Callable] = None,
            trigger_rule: str = TriggerRule.ALL_SUCCESS,
            resources: Optional[Dict] = None,
            run_as_user: Optional[str] = None,
            task_concurrency: Optional[int] = None,
            executor_config: Optional[Dict] = None,
            do_xcom_push: bool = True,
            inlets: Optional[Any] = None,
            outlets: Optional[Any] = None,
            *args,
            **kwargs):
        from airflow.models.dag import DagContext
        super().__init__()
        if args or kwargs:
            if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
                raise AirflowException(
                    "Invalid arguments were passed to {c} (task_id: {t}). Invalid "
                    "arguments were:\n*args: {a}\n**kwargs: {k}".format(
                        c=self.__class__.__name__, a=args, k=kwargs,
                        t=task_id), )
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'future. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_execute_callback = on_execute_callback
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            # noinspection PyTypeChecker
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules,
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule
        self.resources: Optional[Resources] = Resources(
            **resources) if resources else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids: Set[str] = set()
        self._downstream_task_ids: Set[str] = set()
        self._dag = None

        self.dag = dag or DagContext.get_current_dag()

        # subdag parameter is only set for SubDagOperator.
        # Setting it to None by default as other Operators do not have that field
        from airflow.models.dag import DAG
        self.subdag: Optional[DAG] = None

        self._log = logging.getLogger("airflow.task.operators")

        # Lineage
        self.inlets: List = []
        self.outlets: List = []

        self._inlets: List = []
        self._outlets: List = []

        if inlets:
            self._inlets = inlets if isinstance(inlets, list) else [
                inlets,
            ]

        if outlets:
            self._outlets = outlets if isinstance(outlets, list) else [
                outlets,
            ]
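The assignment self.dag = dag or DagContext.get_current_dag() near the end is what lets operators created inside a `with DAG(...)` block attach themselves without an explicit dag= argument. A hedged usage sketch with DummyOperator (present in Airflow 1.10 and 2.x):

import datetime
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

with DAG('demo', start_date=datetime.datetime(2021, 1, 1)) as dag:
    task = DummyOperator(task_id='noop')  # no dag= needed inside the block

assert task.dag is dag  # picked up via DagContext in BaseOperator.__init__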
Example #14
    def __init__(
        self,
        group_id: Optional[str],
        prefix_group_id: bool = True,
        parent_group: Optional["TaskGroup"] = None,
        dag: Optional["DAG"] = None,
        default_args: Optional[Dict] = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
        add_suffix_on_collision: bool = False,
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id
        self.default_args = copy.deepcopy(default_args or {})

        dag = dag or DagContext.get_current_dag()

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException(
                    "Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids = set()
            self.dag = dag
        else:
            if prefix_group_id:
                # If group id is used as prefix, it should not contain spaces nor dots
                # because it is used as prefix in the task_id
                validate_group_key(group_id)
            else:
                if not isinstance(group_id, str):
                    raise ValueError("group_id must be str")
                if not group_id:
                    raise ValueError("group_id must not be empty")

            if not parent_group and not dag:
                raise AirflowException(
                    "TaskGroup can only be used inside a dag")

            parent_group = parent_group or TaskGroupContext.get_current_task_group(
                dag)
            if not parent_group:
                raise AirflowException(
                    "TaskGroup must have a parent_group except for the root TaskGroup"
                )
            if dag is not parent_group.dag:
                raise RuntimeError(
                    f"Cannot mix TaskGroups from different DAGs: {dag} and {parent_group.dag}"
                )

            self.used_group_ids = parent_group.used_group_ids

        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        self._group_id = group_id
        self._check_for_group_id_collisions(add_suffix_on_collision)

        self.children: Dict[str, DAGNode] = {}
        if parent_group:
            parent_group.add(self)

        self.used_group_ids.add(self.group_id)
        if self.group_id:
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: Set[Optional[str]] = set()
        self.downstream_group_ids: Set[Optional[str]] = set()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()
Example #15
    def __init__(
        self,
        group_id: Optional[str],
        prefix_group_id: bool = True,
        parent_group: Optional["TaskGroup"] = None,
        dag: Optional["DAG"] = None,
        default_args: Optional[Dict] = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
        add_suffix_on_collision: bool = False,
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id
        self.default_args = copy.deepcopy(default_args or {})

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException(
                    "Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids: Set[Optional[str]] = set()
            self._parent_group = None
        else:
            if not isinstance(group_id, str):
                raise ValueError("group_id must be str")
            if not group_id:
                raise ValueError("group_id must not be empty")

            dag = dag or DagContext.get_current_dag()

            if not parent_group and not dag:
                raise AirflowException(
                    "TaskGroup can only be used inside a dag")

            self._parent_group = parent_group or TaskGroupContext.get_current_task_group(
                dag)
            if not self._parent_group:
                raise AirflowException(
                    "TaskGroup must have a parent_group except for the root TaskGroup"
                )
            self.used_group_ids = self._parent_group.used_group_ids

        self._group_id = group_id
        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        if group_id in self.used_group_ids:
            if not add_suffix_on_collision:
                raise DuplicateTaskIdFound(
                    f"group_id '{self.group_id}' has already been added to the DAG"
                )
            base = re.split(r'__\d+$', group_id)[0]
            suffixes = sorted(
                int(re.split(r'^.+__', used_group_id)[1])
                for used_group_id in self.used_group_ids
                if used_group_id is not None
                and re.match(rf'^{base}__\d+$', used_group_id))
            if not suffixes:
                self._group_id += '__1'
            else:
                self._group_id = f'{base}__{suffixes[-1] + 1}'

        self.used_group_ids.add(self.group_id)
        self.used_group_ids.add(self.downstream_join_id)
        self.used_group_ids.add(self.upstream_join_id)
        self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
        if self._parent_group:
            self._parent_group.add(self)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: Set[Optional[str]] = set()
        self.downstream_group_ids: Set[Optional[str]] = set()
        self.upstream_task_ids: Set[Optional[str]] = set()
        self.downstream_task_ids: Set[Optional[str]] = set()
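With add_suffix_on_collision=True, a reused group_id is renamed with the same __N scheme used for task ids. A sketch, assuming an Airflow version whose TaskGroup accepts this flag (as in the signature above):

import datetime
from airflow import DAG
from airflow.utils.task_group import TaskGroup

with DAG('demo', start_date=datetime.datetime(2022, 1, 1)) as dag:
    g1 = TaskGroup('section', add_suffix_on_collision=True)
    g2 = TaskGroup('section', add_suffix_on_collision=True)

print(g1.group_id, g2.group_id)  # -> section section__1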
Example #16
    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
        ensure_xcomarg_return_value(expand_input.value)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
        task_group = task_kwargs.pop(
            "task_group", None) or TaskGroupContext.get_current_task_group(dag)

        partial_kwargs, default_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs.update(task_kwargs)

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag,
                                     task_group)
        params = partial_kwargs.pop("params", None) or default_params

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(
            partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop(
            "end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(
            partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_timedelta(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
            key="retry_delay",
        )
        max_retry_delay = partial_kwargs.get("max_retry_delay")
        partial_kwargs["max_retry_delay"] = (
            max_retry_delay if max_retry_delay is None else coerce_timedelta(
                max_retry_delay, key="max_retry_delay"))
        partial_kwargs["resources"] = coerce_resources(
            partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        operator = _MappedOperator(
            operator_class=self.operator_class,
            expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            op_kwargs_expand_input=expand_input,
            disallow_kwargs_override=strict,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expand_input_attr="op_kwargs_expand_input",
        )
        return XComArg(operator=operator)
Example #17
    def expand(self, **map_kwargs: "Mappable") -> XComArg:
        self._validate_arg_names("expand", map_kwargs)
        prevent_duplicates(self.kwargs,
                           map_kwargs,
                           fail_reason="mapping already partial")
        ensure_xcomarg_return_value(map_kwargs)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
        task_group = task_kwargs.pop(
            "task_group", None) or TaskGroupContext.get_current_task_group(dag)

        partial_kwargs, default_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs.update(task_kwargs)

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag,
                                     task_group)
        params = partial_kwargs.pop("params", None) or default_params

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(
            partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop(
            "end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(
            partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_retry_delay(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), )
        partial_kwargs["resources"] = coerce_resources(
            partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)
        operator = _MappedOperator(
            operator_class=self.operator_class,
            mapped_kwargs={},
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            mapped_op_kwargs=map_kwargs,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expansion_kwargs_attr="mapped_op_kwargs",
        )
        return XComArg(operator=operator)
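End to end, this is the path taken when a taskflow task is mapped. A hedged usage sketch, assuming Airflow 2.3+ dynamic task mapping:

import datetime
from airflow.decorators import dag, task

@task
def add_one(x: int) -> int:
    return x + 1

@dag(start_date=datetime.datetime(2022, 6, 1), schedule_interval=None)
def mapped_example():
    # expand() goes through the method above and builds a DecoratedMappedOperator;
    # one task instance is created per element of x at run time.
    add_one.expand(x=[1, 2, 3])

mapped_dag = mapped_example()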