def __attrs_post_init__(self):
    """Run the parent post-init hook, then register upstreams from ``op_kwargs_expand_input``.

    After the parent class's ``__attrs_post_init__`` has run, every XComArg reachable
    from ``self.op_kwargs_expand_input.value`` is handed to
    ``XComArg.apply_upstream_relationship`` so it becomes an upstream of this task.
    """
    # Zero-argument super() does not work in this method, so the explicit
    # two-argument form is used. super(..., self) is deliberately avoided to
    # work around a pyupgrade bug.
    parent_post_init = super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__
    parent_post_init(self)

    expanded_kwargs_value = self.op_kwargs_expand_input.value
    XComArg.apply_upstream_relationship(self, expanded_kwargs_value)
def test_xcom_ctor(self):
    """XComArg keeps its operator and key, compares by value, and renders an xcom_pull call."""
    python_op = build_python_op()
    xcom_arg = XComArg(python_op, "test_key")

    assert xcom_arg
    assert xcom_arg.operator == python_op
    assert xcom_arg.key == "test_key"
    # Goes through XComArg's overridden __eq__.
    assert xcom_arg == XComArg(python_op, "test_key")

    expected_repr = (
        "task_instance.xcom_pull("
        "task_ids='test_xcom_op', "
        "dag_id='test_xcom_dag', "
        "key='test_key')"
    )
    assert str(xcom_arg) == expected_repr
def test_set_upstream(self):
    """Chaining ``<<`` over XComArgs and an operator gives op_a both op_b and bash_op upstream."""
    with DAG("test_set_upstream", default_args=DEFAULT_ARGS):
        op_a = BashOperator(task_id="a", bash_command="echo a")
        op_b = BashOperator(task_id="b", bash_command="echo b")
        bash_op = BashOperator(task_id="c", bash_command="echo c")
        xcom_args_a = XComArg(op_a)
        xcom_args_b = XComArg(op_b)

        xcom_args_a << xcom_args_b << bash_op

    upstream_of_a = op_a.upstream_list
    assert len(upstream_of_a) == 2
    assert op_b in upstream_of_a
    assert bash_op in upstream_of_a
def test_xcomarg_set(self, test_dag):
    """set_downstream on an XComArg with a Label records the label as edge info on the DAG."""
    # The fixture yields the DAG together with its four operators.
    dag, operators = test_dag
    op1, op2, op3, op4 = operators

    # Only the op1 -> op2 edge carries a Label; op1 -> op3/op4 are plain edges.
    XComArg(op1, "test_key").set_downstream(op2, Label("Label 1"))
    op1.set_downstream([op3, op4])

    assert dag.get_edge_info(op1.task_id, op2.task_id) == {"label": "Label 1"}
    assert dag.get_edge_info(op1.task_id, op4.task_id) == {}
def test_set_downstream(self):
    """Chaining ``>>`` through XComArgs wires downstreams on the underlying operators.

    After ``bash_op1 >> xcom_args_a >> xcom_args_b >> bash_op2`` the operator chain
    must be c -> a -> b -> d.
    """
    with DAG("test_set_downstream", default_args=DEFAULT_ARGS):
        op_a = BashOperator(task_id="a", bash_command="echo a")
        op_b = BashOperator(task_id="b", bash_command="echo b")
        bash_op1 = BashOperator(task_id="c", bash_command="echo c")
        # Each task echoes its own task id.
        bash_op2 = BashOperator(task_id="d", bash_command="echo d")
        xcom_args_a = XComArg(op_a)
        xcom_args_b = XComArg(op_b)

        bash_op1 >> xcom_args_a >> xcom_args_b >> bash_op2

    assert op_a in bash_op1.downstream_list
    assert op_b in op_a.downstream_list
    assert bash_op2 in op_b.downstream_list
def __attrs_post_init__(self):
    """Validate, register with the task group / DAG, then wire XComArg upstreams.

    For each templated field in ``mapped_kwargs`` and then in ``partial_kwargs``,
    the value is passed to ``XComArg.apply_upstream_relationship`` so any XComArg
    it contains becomes an upstream of this task.
    """
    from airflow.models.xcom_arg import XComArg

    self._validate_argument_count()
    if self.task_group:
        self.task_group.add(self)
    if self.dag:
        self.dag.add_task(self)

    # mapped_kwargs is processed before partial_kwargs; each attribute is read
    # only when its turn comes.
    for kwargs_attr in ("mapped_kwargs", "partial_kwargs"):
        for field_name, field_value in getattr(self, kwargs_attr).items():
            if field_name in self.template_fields:
                XComArg.apply_upstream_relationship(self, field_value)
def test_xcom_arg(self):
    """An XComArg passed to a decorated task resolves to the upstream's return value.

    add_num receives add_2's result twice (once as the task output and once as an
    explicit XComArg), so its return value must be (10 + 2) * 2.
    """

    @task_decorator
    def add_2(number: int):
        return number + 2

    @task_decorator
    def add_num(number: int, num2: int = 2):
        return number + num2

    test_number = 10
    with self.dag:
        bigger_number = add_2(test_number)
        ret = add_num(bigger_number, XComArg(bigger_number.operator))  # pylint: disable=maybe-no-member

    dr = self.dag.create_dagrun(
        run_id=DagRunType.MANUAL.value,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    bigger_number.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)  # pylint: disable=maybe-no-member
    ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)  # pylint: disable=maybe-no-member

    add_num_instances = [ti for ti in dr.get_task_instances() if ti.task_id == 'add_num']
    ti_add_num = add_num_instances[0]
    expected_value = (test_number + 2) * 2
    assert ti_add_num.xcom_pull(key=ret.key) == expected_value  # pylint: disable=maybe-no-member
def __attrs_post_init__(self):
    """Validate the mapped operator, register it, and wire XComArg upstreams.

    Upstream relationships are derived from ``self.expand_input.value`` and from
    every templated field in ``partial_kwargs``.

    Raises:
        AirflowException: if ``partial_kwargs`` carries a non-None ``sla``; SLAs
            are not supported for mapped tasks.
    """
    from airflow.models.xcom_arg import XComArg

    self._validate_argument_count()
    # Reject unsupported options *before* adding the operator to its task group
    # or DAG, so a failed construction does not leave it registered there.
    if self.partial_kwargs.get('sla') is not None:
        raise AirflowException(
            f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
            f"{self.task_id!r}."
        )

    if self.task_group:
        self.task_group.add(self)
    if self.dag:
        self.dag.add_task(self)

    XComArg.apply_upstream_relationship(self, self.expand_input.value)
    for k, v in self.partial_kwargs.items():
        if k in self.template_fields:
            XComArg.apply_upstream_relationship(self, v)
def test_xcom_key_is_empty_str(self):
    """An empty-string key is preserved and rendered as ``key=''``."""
    xcom_arg = XComArg(build_python_op(), key="")
    assert xcom_arg.key == ""

    rendered = str(xcom_arg)
    assert rendered == (
        "task_instance.xcom_pull(task_ids='test_xcom_op', "
        "dag_id='test_xcom_dag', key='')"
    )
def iter_mapped_dependencies(self) -> Iterator["Operator"]:
    """Upstream dependencies that provide XComs used by this task for task mapping.

    Walks every XComArg found in the specified expand input and yields the operator
    of each reference that XComArg reports.
    """
    from airflow.models.xcom_arg import XComArg

    expand_input = self._get_specified_expand_input()
    for xcom_arg in XComArg.iter_xcom_args(expand_input):
        yield from (operator for operator, _ in xcom_arg.iter_references())
def factory(*args, **f_kwargs):
    # Build a _PythonDecoratedOperator for one call of the decorated function ``f``:
    # positional/keyword call arguments become op_args/op_kwargs, and the returned
    # XComArg references the new operator's output.
    # (Comments rather than a docstring: this closure's __doc__ may be user-visible.)
    op = _PythonDecoratedOperator(
        python_callable=f,
        op_args=args,
        op_kwargs=f_kwargs,
        multiple_outputs=multiple_outputs,
        **kwargs,
    )
    # Surface the wrapped function's docstring as task documentation, consistent
    # with the other decorator factories.
    if f.__doc__:
        op.doc_md = f.__doc__
    return XComArg(op)
def map(self, **kwargs: "MapArgument") -> XComArg:
    """Create a mapped (DecoratedMappedOperator) task from this decorator and return its XComArg.

    The decorator's own ``self.kwargs`` become the operator's ``partial_kwargs``
    (after the normalisation below); the ``kwargs`` given here are passed as
    ``mapped_op_kwargs``, i.e. the values the wrapped callable's arguments are
    mapped over.

    Args:
        **kwargs: Values to map over. Names are checked by ``_validate_arg_names``
            (presumably raising on invalid names -- confirm).

    Returns:
        XComArg referencing the newly created mapped operator.

    Raises:
        TypeError: if the decorator kwargs contain ``task_concurrency``.
    """
    self._validate_arg_names("map", kwargs)
    # Work on a copy so the decorator's stored kwargs are not mutated by the pops below.
    partial_kwargs = self.kwargs.copy()
    # NOTE(review): the pop() defaults are evaluated eagerly, and an explicit
    # dag=None / task_group=None in kwargs is kept as None rather than falling
    # back to the current context.
    dag = partial_kwargs.pop("dag", DagContext.get_current_dag())
    task_group = partial_kwargs.pop(
        "task_group", TaskGroupContext.get_current_task_group(dag))
    task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
    params = partial_kwargs.pop("params", None)

    # Logic here should be kept in sync with BaseOperatorMeta.partial().
    if "task_concurrency" in partial_kwargs:
        raise TypeError("unexpected argument: task_concurrency")
    if partial_kwargs.get("wait_for_downstream"):
        partial_kwargs["depends_on_past"] = True
    start_date = timezone.convert_to_utc(
        partial_kwargs.pop("start_date", None))
    end_date = timezone.convert_to_utc(partial_kwargs.pop(
        "end_date", None))
    if partial_kwargs.get("pool") is None:
        partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
    partial_kwargs["retries"] = parse_retries(
        partial_kwargs.get("retries", DEFAULT_RETRIES))
    partial_kwargs["retry_delay"] = coerce_retry_delay(
        partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
    )
    partial_kwargs["resources"] = coerce_resources(
        partial_kwargs.get("resources"))
    # Ensure these keys always exist on the partial kwargs, defaulting to empty containers.
    partial_kwargs.setdefault("executor_config", {})
    partial_kwargs.setdefault("op_args", [])
    partial_kwargs.setdefault("op_kwargs", {})

    # Mypy does not work well with a subclassed attrs class :(
    _MappedOperator = cast(Any, DecoratedMappedOperator)
    operator = _MappedOperator(
        operator_class=self.operator_class,
        # Mapped values are carried in mapped_op_kwargs instead.
        mapped_kwargs={},
        partial_kwargs=partial_kwargs,
        task_id=task_id,
        params=params,
        deps=MappedOperator.deps_for(self.operator_class),
        operator_extra_links=self.operator_class.operator_extra_links,
        template_ext=self.operator_class.template_ext,
        template_fields=self.operator_class.template_fields,
        ui_color=self.operator_class.ui_color,
        ui_fgcolor=self.operator_class.ui_fgcolor,
        is_dummy=False,
        task_module=self.operator_class.__module__,
        task_type=self.operator_class.__name__,
        dag=dag,
        task_group=task_group,
        start_date=start_date,
        end_date=end_date,
        multiple_outputs=self.multiple_outputs,
        python_callable=self.function,
        mapped_op_kwargs=kwargs,
    )
    return XComArg(operator=operator)
def test_xcom_ctor(self):
    """XComArg keeps operator and key, compares by value, and renders as a Jinja xcom_pull template."""
    python_op = build_python_op()
    xcom_arg = XComArg(python_op, "test_key")

    assert xcom_arg
    assert xcom_arg.operator == python_op
    assert xcom_arg.key == "test_key"
    # Goes through XComArg's overridden __eq__.
    assert xcom_arg == XComArg(python_op, "test_key")

    template = (
        "{{ task_instance.xcom_pull(task_ids='test_xcom_op', "
        "dag_id='test_xcom_dag', key='test_key') }}"
    )
    assert str(xcom_arg) == template
    # Interpolating into an f-string must yield the same template text.
    assert f"echo {xcom_arg}" == "echo " + template
def factory(*args, **f_kwargs):
    # Instantiate ``decorated_operator_class`` for one call of the decorated
    # function ``f`` and return an XComArg pointing at that operator.
    # (Comments rather than a docstring: this closure's __doc__ may be user-visible.)
    op = decorated_operator_class(
        python_callable=f,
        op_args=args,
        op_kwargs=f_kwargs,
        multiple_outputs=multiple_outputs,
        **kwargs,
    )
    func_doc = f.__doc__
    if func_doc:
        # Reuse the wrapped function's docstring as the task's doc_md.
        op.doc_md = func_doc
    return XComArg(op)
def __call__(self, *args, **kwargs) -> XComArg:
    """Instantiate ``self.operator_class`` for one call of the wrapped function.

    Call arguments become ``op_args`` / ``op_kwargs``; ``self.kwargs`` are forwarded
    to the operator constructor. If the wrapped function has a non-empty docstring it
    is copied to the operator's ``doc_md``.

    Returns:
        XComArg referencing the new operator.
    """
    op = self.operator_class(
        python_callable=self.function,
        op_args=args,
        op_kwargs=kwargs,
        multiple_outputs=self.multiple_outputs,
        **self.kwargs,
    )
    function_doc = self.function.__doc__
    if function_doc:
        op.doc_md = function_doc
    return XComArg(op)
def factory(*args, **f_kwargs):
    # Wrap ``f`` with ExperimentStep and run the wrapped callable inside a
    # _PythonDecoratedOperator; doc_md is still taken from the original ``f``.
    # (Comments rather than a docstring: this closure's __doc__ may be user-visible.)
    step_callable = ExperimentStep()(f)
    op = _PythonDecoratedOperator(
        python_callable=step_callable,
        op_args=args,
        op_kwargs=f_kwargs,
        multiple_outputs=multiple_outputs,
        **kwargs,
    )
    original_doc = f.__doc__
    if original_doc:
        op.doc_md = original_doc
    return XComArg(op)
def test_xcomarg_shift(self, test_dag):
    """A Label placed in a ``>>`` chain from an XComArg is recorded as edge info."""
    # The fixture yields the DAG together with its four operators.
    dag, (op1, op2, op3, op4) = test_dag

    op1_arg = XComArg(op1, "test_key")
    # op1 -> [op2, op3] through a Label; op1 -> op4 without one.
    op1_arg >> Label("Label 1") >> [op2, op3]  # pylint: disable=W0106
    op1_arg >> op4

    labelled_edge = dag.get_edge_info(op1.task_id, op2.task_id)
    assert labelled_edge == {"label": "Label 1"}
    plain_edge = dag.get_edge_info(op1.task_id, op4.task_id)
    assert plain_edge == {}
def test_xcom_pass_to_op(self):
    """An XComArg given in op_args is resolved when the downstream task runs.

    The consumer calls ``assert_is_value`` on the producer's return value (VALUE).
    """
    with DAG(dag_id="test_xcom_pass_to_op", default_args=DEFAULT_ARGS) as dag:
        producer = PythonOperator(
            python_callable=lambda: VALUE,
            task_id="return_value_1",
            do_xcom_push=True,
        )
        consumer = PythonOperator(
            python_callable=assert_is_value,
            op_args=[XComArg(producer)],
            task_id="assert_is_value_1",
        )
        producer >> consumer
    dag.run()
def map(
    self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
) -> XComArg:
    """Create a mapped task via ``MappedOperator.from_decorator`` and return its XComArg.

    ``dag`` / ``task_group`` fall back to the current context when not given (or
    falsy). The task id is made unique within that DAG/group from
    ``self.kwargs['task_id']``.
    """
    self._validate_arg_names("map", kwargs)

    target_dag = dag or DagContext.get_current_dag()
    target_group = task_group or TaskGroupContext.get_current_task_group(target_dag)
    unique_task_id = get_unique_task_id(self.kwargs['task_id'], target_dag, target_group)

    mapped_operator = MappedOperator.from_decorator(
        decorator=self,
        dag=target_dag,
        task_group=target_group,
        task_id=unique_task_id,
        mapped_kwargs=kwargs,
    )
    return XComArg(operator=mapped_operator)
def test_xcom_push_and_pass(self):
    """A value pushed under a custom key reaches a downstream task via ``XComArg(op, key=...)``."""

    def push_xcom_value(key, value, **context):
        # Push explicitly under ``key`` using the running task instance.
        context["task_instance"].xcom_push(key, value)

    with DAG(dag_id="test_xcom_push_and_pass", default_args=DEFAULT_ARGS) as dag:
        pusher = PythonOperator(
            python_callable=push_xcom_value,
            task_id="push_xcom_value",
            op_args=["my_key", VALUE],
        )
        checker = PythonOperator(
            python_callable=assert_is_value,
            task_id="assert_is_value_1",
            op_args=[XComArg(pusher, key="my_key")],
        )
        pusher >> checker
    dag.run()
def map(self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs) -> XComArg:
    """Create a MappedOperator over ``kwargs`` from this decorator and return its XComArg.

    ``dag`` / ``task_group`` fall back to the current context when not given (or
    falsy). ``self.kwargs`` is passed as ``partial_kwargs``.
    """
    resolved_dag = dag or DagContext.get_current_dag()
    resolved_group = task_group or TaskGroupContext.get_current_task_group(resolved_dag)
    unique_task_id = get_unique_task_id(self.kwargs['task_id'], resolved_dag, resolved_group)
    self._validate_arg_names("map", kwargs)

    mapped_operator = MappedOperator(
        operator_class=self.operator_class,
        task_id=unique_task_id,
        dag=resolved_dag,
        task_group=resolved_group,
        partial_kwargs=self.kwargs,
        # Empty at construction time to bypass MappedOperator's validation;
        # decorated callables are validated by _validate_arg_names instead.
        mapped_kwargs={},
    )
    # The real mapped arguments are filled in only after construction.
    mapped_operator.mapped_kwargs.update(kwargs)
    return XComArg(operator=mapped_operator)
def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
    """Deserializes a DAG from a JSON object.

    Each key of ``encoded_dag`` is decoded according to its name and set as an
    attribute on a fresh ``SerializedDAG``. Afterwards, schedule_interval/timetable
    are back-filled, the task group is rebuilt, missing serialized fields are set to
    None, and each task is re-linked to the DAG (dates, subdag parent, mapped
    XComArg references, upstream ids).

    Args:
        encoded_dag: The serialized DAG mapping; must contain ``'_dag_id'``.

    Returns:
        The reconstructed ``SerializedDAG``.
    """
    dag = SerializedDAG(dag_id=encoded_dag['_dag_id'])

    for k, v in encoded_dag.items():
        if k == "_downstream_task_ids":
            v = set(v)
        elif k == "tasks":
            # NOTE(review): this assigns a class attribute on SerializedBaseOperator,
            # i.e. global state shared by all later deserializations.
            SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links

            v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v}
            # Tasks are stored on the DAG under ``task_dict``, keyed by task_id.
            k = "task_dict"
        elif k == "timezone":
            v = cls._deserialize_timezone(v)
        elif k == "dagrun_timeout":
            v = cls._deserialize_timedelta(v)
        elif k.endswith("_date"):
            v = cls._deserialize_datetime(v)
        elif k == "edge_info":
            # Value structure matches exactly
            pass
        elif k == "timetable":
            v = _decode_timetable(v)
        elif k in cls._decorated_fields:
            v = cls._deserialize(v)
        elif k == "params":
            v = cls._deserialize_params_dict(v)
        # else use v as it is

        setattr(dag, k, v)

    # A DAG is always serialized with only one of schedule_interval and
    # timetable. This back-populates the other to ensure the two attributes
    # line up correctly on the DAG instance.
    if "timetable" in encoded_dag:
        dag.schedule_interval = dag.timetable.summary
    else:
        dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)

    # Set _task_group
    if "_task_group" in encoded_dag:
        dag._task_group = SerializedTaskGroup.deserialize_task_group(
            encoded_dag["_task_group"], None, dag.task_dict, dag
        )
    else:
        # This must be old data that had no task_group. Create a root TaskGroup and add
        # all tasks to it.
        dag._task_group = TaskGroup.create_root(dag)
        for task in dag.tasks:
            dag.task_group.add(task)

    # Set has_on_*_callbacks to True if they exist in Serialized blob as False is the default
    if "has_on_success_callback" in encoded_dag:
        dag.has_on_success_callback = True
    if "has_on_failure_callback" in encoded_dag:
        dag.has_on_failure_callback = True

    # Serialized fields absent from the blob (and not constructor params) are reset to None.
    keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
    for k in keys_to_set_none:
        setattr(dag, k, None)

    for task in dag.task_dict.values():
        task.dag = dag

        # Tasks without their own start/end date inherit the DAG's.
        for date_attr in ["start_date", "end_date"]:
            if getattr(task, date_attr) is None:
                setattr(task, date_attr, getattr(dag, date_attr))

        if task.subdag is not None:
            setattr(task.subdag, 'parent_dag', dag)

        if isinstance(task, MappedOperator):
            # Replace serialized _XComRef placeholders with live XComArgs pointing at
            # the deserialized tasks. Only existing keys are reassigned, so mutating
            # the dict while iterating its items() is safe (its size never changes).
            expansion_kwargs = task._get_expansion_kwargs()
            for k, v in expansion_kwargs.items():
                if not isinstance(v, _XComRef):
                    continue

                expansion_kwargs[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key)

        for task_id in task.downstream_task_ids:
            # Bypass set_upstream etc here - it does more than we want
            dag.task_dict[task_id].upstream_task_ids.add(task.task_id)

    return dag
def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
    """Create a DecoratedMappedOperator that expands over ``op_kwargs`` and return its XComArg.

    The decorator's ``self.kwargs`` are merged with DAG/task-group defaults and
    normalised into ``partial_kwargs``; ``expand_input`` is stored as
    ``op_kwargs_expand_input`` so the mapped values feed the callable's op_kwargs.

    Args:
        expand_input: The values to expand over. Its ``.value`` is first passed to
            ``ensure_xcomarg_return_value`` (presumably rejecting XComArgs that do
            not refer to a return value -- confirm).
        strict: Forwarded as ``disallow_kwargs_override``.

    Returns:
        XComArg referencing the newly created mapped operator.

    Raises:
        TypeError: if the decorator kwargs contain ``task_concurrency``.
    """
    ensure_xcomarg_return_value(expand_input.value)

    # Work on a copy so the decorator's stored kwargs are not mutated by the pops below.
    task_kwargs = self.kwargs.copy()
    # Falsy/missing dag or task_group fall back to the current context.
    dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
    task_group = task_kwargs.pop(
        "task_group", None) or TaskGroupContext.get_current_task_group(dag)

    partial_kwargs, default_params = get_merged_defaults(
        dag=dag,
        task_group=task_group,
        task_params=task_kwargs.pop("params", None),
        task_default_args=task_kwargs.pop("default_args", None),
    )
    # Explicit task kwargs override the merged defaults.
    partial_kwargs.update(task_kwargs)

    task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
    params = partial_kwargs.pop("params", None) or default_params

    # Logic here should be kept in sync with BaseOperatorMeta.partial().
    if "task_concurrency" in partial_kwargs:
        raise TypeError("unexpected argument: task_concurrency")
    if partial_kwargs.get("wait_for_downstream"):
        partial_kwargs["depends_on_past"] = True
    start_date = timezone.convert_to_utc(
        partial_kwargs.pop("start_date", None))
    end_date = timezone.convert_to_utc(partial_kwargs.pop(
        "end_date", None))
    if partial_kwargs.get("pool") is None:
        partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
    partial_kwargs["retries"] = parse_retries(
        partial_kwargs.get("retries", DEFAULT_RETRIES))
    partial_kwargs["retry_delay"] = coerce_timedelta(
        partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
        key="retry_delay",
    )
    # max_retry_delay stays None when unset; otherwise it is coerced like retry_delay.
    max_retry_delay = partial_kwargs.get("max_retry_delay")
    partial_kwargs["max_retry_delay"] = (
        max_retry_delay if max_retry_delay is None else coerce_timedelta(
            max_retry_delay, key="max_retry_delay"))
    partial_kwargs["resources"] = coerce_resources(
        partial_kwargs.get("resources"))
    # Ensure these keys always exist on the partial kwargs, defaulting to empty containers.
    partial_kwargs.setdefault("executor_config", {})
    partial_kwargs.setdefault("op_args", [])
    partial_kwargs.setdefault("op_kwargs", {})

    # Mypy does not work well with a subclassed attrs class :(
    _MappedOperator = cast(Any, DecoratedMappedOperator)

    # Prefer an operator-supplied display name; fall back to the class name.
    try:
        operator_name = self.operator_class.custom_operator_name  # type: ignore
    except AttributeError:
        operator_name = self.operator_class.__name__

    operator = _MappedOperator(
        operator_class=self.operator_class,
        expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
        partial_kwargs=partial_kwargs,
        task_id=task_id,
        params=params,
        deps=MappedOperator.deps_for(self.operator_class),
        operator_extra_links=self.operator_class.operator_extra_links,
        template_ext=self.operator_class.template_ext,
        template_fields=self.operator_class.template_fields,
        template_fields_renderers=self.operator_class.template_fields_renderers,
        ui_color=self.operator_class.ui_color,
        ui_fgcolor=self.operator_class.ui_fgcolor,
        is_empty=False,
        task_module=self.operator_class.__module__,
        task_type=self.operator_class.__name__,
        operator_name=operator_name,
        dag=dag,
        task_group=task_group,
        start_date=start_date,
        end_date=end_date,
        multiple_outputs=self.multiple_outputs,
        python_callable=self.function,
        op_kwargs_expand_input=expand_input,
        disallow_kwargs_override=strict,
        # Different from classic operators, kwargs passed to a taskflow
        # task's expand() contribute to the op_kwargs operator argument, not
        # the operator arguments themselves, and should expand against it.
        expand_input_attr="op_kwargs_expand_input",
    )
    return XComArg(operator=operator)
def iter_mapped_dependencies(self) -> Iterator["Operator"]:
    """Upstream dependencies that provide XComs used by this task for task mapping.

    Yields the ``operator`` of every XComArg found in this task's expansion kwargs.
    """
    from airflow.models.xcom_arg import XComArg

    yield from (xcom_arg.operator for xcom_arg in XComArg.iter_xcom_args(self._get_expansion_kwargs()))
def test_xcom_key_getitem(self):
    """Indexing an XComArg with a string gives an object whose key is that string."""
    xcom_arg = XComArg(build_python_op(), key="another_key")
    assert xcom_arg.key == "another_key"

    rekeyed = xcom_arg["another_key_2"]
    assert rekeyed.key == "another_key_2"
def test_xcom_arg_property_of_base_operator(self):
    """``BaseOperator.output`` equals an XComArg built directly from the operator."""
    with DAG("test_xcom_arg_property_of_base_operator", default_args=DEFAULT_ARGS):
        op_a = BashOperator(task_id="a", bash_command="echo a")

    output_arg = op_a.output
    assert output_arg == XComArg(op_a)
def expand(self, **map_kwargs: "Mappable") -> XComArg:
    """Create a DecoratedMappedOperator that expands over ``map_kwargs`` and return its XComArg.

    The decorator's ``self.kwargs`` are merged with DAG/task-group defaults and
    normalised into ``partial_kwargs``; ``map_kwargs`` are stored as
    ``mapped_op_kwargs`` so the mapped values feed the callable's op_kwargs.

    Args:
        **map_kwargs: Values to map over. They are checked by
            ``_validate_arg_names``, must not duplicate keys already in
            ``self.kwargs`` (``prevent_duplicates``), and are passed to
            ``ensure_xcomarg_return_value`` (presumably rejecting XComArgs that
            do not refer to a return value -- confirm).

    Returns:
        XComArg referencing the newly created mapped operator.

    Raises:
        TypeError: if the decorator kwargs contain ``task_concurrency``.
    """
    self._validate_arg_names("expand", map_kwargs)
    prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
    ensure_xcomarg_return_value(map_kwargs)

    # Work on a copy so the decorator's stored kwargs are not mutated by the pops below.
    task_kwargs = self.kwargs.copy()
    # Falsy/missing dag or task_group fall back to the current context.
    dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
    task_group = task_kwargs.pop(
        "task_group", None) or TaskGroupContext.get_current_task_group(dag)

    partial_kwargs, default_params = get_merged_defaults(
        dag=dag,
        task_group=task_group,
        task_params=task_kwargs.pop("params", None),
        task_default_args=task_kwargs.pop("default_args", None),
    )
    # Explicit task kwargs override the merged defaults.
    partial_kwargs.update(task_kwargs)

    task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
    params = partial_kwargs.pop("params", None) or default_params

    # Logic here should be kept in sync with BaseOperatorMeta.partial().
    if "task_concurrency" in partial_kwargs:
        raise TypeError("unexpected argument: task_concurrency")
    if partial_kwargs.get("wait_for_downstream"):
        partial_kwargs["depends_on_past"] = True
    start_date = timezone.convert_to_utc(
        partial_kwargs.pop("start_date", None))
    end_date = timezone.convert_to_utc(partial_kwargs.pop(
        "end_date", None))
    if partial_kwargs.get("pool") is None:
        partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
    partial_kwargs["retries"] = parse_retries(
        partial_kwargs.get("retries", DEFAULT_RETRIES))
    partial_kwargs["retry_delay"] = coerce_retry_delay(
        partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
    )
    partial_kwargs["resources"] = coerce_resources(
        partial_kwargs.get("resources"))
    # Ensure these keys always exist on the partial kwargs, defaulting to empty containers.
    partial_kwargs.setdefault("executor_config", {})
    partial_kwargs.setdefault("op_args", [])
    partial_kwargs.setdefault("op_kwargs", {})

    # Mypy does not work well with a subclassed attrs class :(
    _MappedOperator = cast(Any, DecoratedMappedOperator)
    operator = _MappedOperator(
        operator_class=self.operator_class,
        # Mapped values are carried in mapped_op_kwargs instead.
        mapped_kwargs={},
        partial_kwargs=partial_kwargs,
        task_id=task_id,
        params=params,
        deps=MappedOperator.deps_for(self.operator_class),
        operator_extra_links=self.operator_class.operator_extra_links,
        template_ext=self.operator_class.template_ext,
        template_fields=self.operator_class.template_fields,
        template_fields_renderers=self.operator_class.template_fields_renderers,
        ui_color=self.operator_class.ui_color,
        ui_fgcolor=self.operator_class.ui_fgcolor,
        is_empty=False,
        task_module=self.operator_class.__module__,
        task_type=self.operator_class.__name__,
        dag=dag,
        task_group=task_group,
        start_date=start_date,
        end_date=end_date,
        multiple_outputs=self.multiple_outputs,
        python_callable=self.function,
        mapped_op_kwargs=map_kwargs,
        # Different from classic operators, kwargs passed to a taskflow
        # task's expand() contribute to the op_kwargs operator argument, not
        # the operator arguments themselves, and should expand against it.
        expansion_kwargs_attr="mapped_op_kwargs",
    )
    return XComArg(operator=operator)