def subdag_d():
    subdag_d = DAG('nested_cycle.op_subdag_1.opSubdag_D', default_args=default_args)
    DummyOperator(task_id='subdag_d.task', dag=subdag_d)
    return subdag_d
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example of the LatestOnlyOperator"""
import datetime as dt

from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.latest_only import LatestOnlyOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id='latest_only',
    schedule_interval=dt.timedelta(hours=4),
    start_date=days_ago(2),
    tags=['example2', 'example3'],
) as dag:
    latest_only = LatestOnlyOperator(task_id='latest_only')
    task1 = DummyOperator(task_id='task1')

    latest_only >> task1
from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.utils.dates import days_ago
from airflow.sensors.filesystem import FileSensor

from utils import default_args, VOLUME

with DAG(
    "2_train_pipeline",
    default_args=default_args,
    schedule_interval="@weekly",
    start_date=days_ago(5),
) as dag:
    start = DummyOperator(task_id="Begin")

    data_sensor = FileSensor(
        task_id="Wait_for_data",
        poke_interval=10,
        retries=100,
        filepath="data/raw/{{ ds }}/data.csv",
    )
    target_sensor = FileSensor(
        task_id="Wait_for_target",
        poke_interval=10,
        retries=100,
        filepath="data/raw/{{ ds }}/target.csv",
    )

    preprocess = DockerOperator(
        task_id="Data_preprocess",
        image="airflow-preprocess",
        command=
def bs_disaster_dag():
    @task()
    def extract_transform():
        df = pd.read_csv(f"{DATA_PATH}/disaster_data.csv")
        columns = ['text', 'location']
        for column in columns:
            df[column] = df[column].str.replace(r'\s{2,}', ' ', regex=True)
            df[column] = df[column].str.replace(r"[^a-zA-Z0-9\,]", ' ', regex=True)
        df.to_csv(OUT_PATH, index=False, header=False)

    start = DummyOperator(task_id='start')
    end = DummyOperator(task_id='end')
    extract_transform_task = extract_transform()

    stored_data_gcs = LocalFilesystemToGCSOperator(
        task_id="store_to_gcs",
        gcp_conn_id=GOOGLE_CLOUD_CONN_ID,
        src=OUT_PATH,
        dst=GCS_OBJECT_NAME,
        bucket=BUCKET_NAME,
    )

    loaded_data_bigquery = GCSToBigQueryOperator(
        task_id='load_to_bigquery',
        bigquery_conn_id=GOOGLE_CLOUD_CONN_ID,
        bucket=BUCKET_NAME,
        source_objects=[GCS_OBJECT_NAME],
        destination_project_dataset_table=f"{DATASET_ID}.{BIGQUERY_TABLE_NAME}",
        schema_fields=[  # based on https://cloud.google.com/bigquery/docs/schemas
            {'name': 'id', 'type': 'INT64', 'mode': 'REQUIRED'},
            {'name': 'keyword', 'type': 'STRING', 'mode': 'NULLABLE'},
            {'name': 'location', 'type': 'STRING', 'mode': 'NULLABLE'},
            {'name': 'text', 'type': 'STRING', 'mode': 'NULLABLE'},
            {'name': 'target', 'type': 'INT64', 'mode': 'NULLABLE'},
        ],
        autodetect=False,
        write_disposition='WRITE_TRUNCATE',  # If the table already exists, overwrite its data
    )

    start >> extract_transform_task
    extract_transform_task >> stored_data_gcs
    stored_data_gcs >> loaded_data_bigquery
    loaded_data_bigquery >> end
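# A quick standalone illustration (sample strings are assumed, not from the original data)
# of what the two regex replacements above do: first collapse runs of whitespace, then
# replace everything that is not a letter, digit or comma with a space.
import pandas as pd

sample = pd.Series(["Flood   in  NYC!!", "fire  #warning"])
sample = sample.str.replace(r'\s{2,}', ' ', regex=True)
sample = sample.str.replace(r"[^a-zA-Z0-9\,]", ' ', regex=True)
print(sample.tolist())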
    's3_bucket': Variable.get('s3_bucket'),
    'arn': Variable.get('arn'),
    'fn': Variable.get('cpv_csvname'),
    'working_dir': os.path.dirname(os.path.abspath(__file__)),
}

with DAG(
    'cpv_from_local_to_redshift',
    default_args=default_args,
    description='Upload the CPV File from local system to S3 and to Redshift',
    schedule_interval=None,
    tags=['dend', 'cpv', 'staging'],
) as dag:
    _docs_md_fp = os.path.join(default_args['working_dir'], 'Readme.md')
    dag.doc_md = open(_docs_md_fp, 'r').read()

    start_cpv = DummyOperator(task_id='start_cpv')
    stop_cpv = DummyOperator(task_id='stop_cpv')

    upload_cpv_to_s3 = S3UploadFromLocal(
        task_id='Upload_cpv_to_s3s',
        s3_folder='staging/cpv_attributes/',
    )
    create_redshift = RedshiftOperator(
        task_id='create_redshift',
        sql='schema_cpv_staging_datalake.sql',
    )
    copy_from_s3 = RedshiftCopyFromS3(
        task_id='copy_from_s3',
        s3_folder='staging/cpv_attributes',
        schema='staging',
        table='cpv_attributes',
        format='csv',
        header=True,
    'email_on_retry': False,
    'start_date': START_DATE,
    'retries': 1,
    'retry_delay': timedelta(minutes=1),
}

dag = DAG(
    DAG_ID,
    default_args=default_args,
    schedule_interval=SCHEDULE_INTERVAL,
    start_date=START_DATE,
)
if hasattr(dag, 'doc_md'):
    dag.doc_md = __doc__
if hasattr(dag, 'catchup'):
    dag.catchup = False

start = DummyOperator(task_id='start', dag=dag)

log_cleanup = """
echo "Getting Configurations..."
BASE_LOG_FOLDER="{{params.directory}}"
WORKER_SLEEP_TIME="{{params.sleep_time}}"

sleep ${WORKER_SLEEP_TIME}s

MAX_LOG_AGE_IN_DAYS='""" + str(DEFAULT_MAX_LOG_AGE_IN_DAYS) + """'
ENABLE_DELETE=""" + str("true" if ENABLE_DELETE else "false") + """
echo "Finished Getting Configurations"
echo ""
args = {
    'owner': 'airflow',
}

with DAG(
    dag_id='example_jdbc_operator',
    default_args=args,
    schedule_interval='0 0 * * *',
    start_date=days_ago(2),
    dagrun_timeout=timedelta(minutes=60),
    tags=['example'],
) as dag:

    run_this_last = DummyOperator(
        task_id='run_this_last',
        dag=dag,
    )

    # [START howto_operator_jdbc_template]
    delete_data = JdbcOperator(
        task_id='delete_data',
        sql='delete from my_schema.my_table where dt = {{ ds }}',
        jdbc_conn_id='my_jdbc_connection',
        autocommit=True,
        dag=dag,
    )
    # [END howto_operator_jdbc_template]

    # [START howto_operator_jdbc]
    insert_data = JdbcOperator(
        task_id='insert_data',
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.dummy import DummyOperator

default_args = {
    'owner': 'Airflow',
    'start_date': datetime(2021, 3, 22),
    'retries': 1,
    'retry_delay': timedelta(seconds=30),
}

with DAG(dag_id='cycle_error', default_args=default_args, schedule_interval=None) as dag:
    t1 = DummyOperator(task_id='t1')
    t2 = DummyOperator(task_id='t2')
    t3 = DummyOperator(task_id='t3')

    # Deliberate cycle: the last edge points back to t1, so the DAG fails to import.
    t1 >> t2 >> t3 >> t1
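# A minimal sketch (assuming Airflow 2.x; in older 2.0 releases the helper is named
# test_cycle) of how the cycle above surfaces: the DAG processor runs a cycle check on
# import, and the same check can be invoked directly.
from airflow.exceptions import AirflowDagCycleException
from airflow.utils.dag_cycle_tester import check_cycle

try:
    check_cycle(dag)  # t1 >> t2 >> t3 >> t1 closes a loop, so this raises
except AirflowDagCycleException as err:
    print(f"Cycle detected: {err}")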
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import datetime

from airflow.models import DAG
from airflow.operators.dummy import DummyOperator

for i in range(1, 2):
    dag = DAG(dag_id=f'test_latest_runs_{i}')
    task = DummyOperator(
        task_id='dummy_task',
        dag=dag,
        owner='airflow',
        start_date=datetime(2016, 2, 1),
    )
def test_lineage(self):
    dag = DAG(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE)

    f1s = "/tmp/does_not_exist_1-{}"
    f2s = "/tmp/does_not_exist_2-{}"
    f3s = "/tmp/does_not_exist_3"
    file1 = File(f1s.format("{{ execution_date }}"))
    file2 = File(f2s.format("{{ execution_date }}"))
    file3 = File(f3s)

    with dag:
        op1 = DummyOperator(
            task_id='leave1',
            inlets=file1,
            outlets=[file2],
        )
        op2 = DummyOperator(task_id='leave2')
        op3 = DummyOperator(task_id='upstream_level_1', inlets=AUTO, outlets=file3)
        op4 = DummyOperator(task_id='upstream_level_2')
        op5 = DummyOperator(task_id='upstream_level_3', inlets=["leave1", "upstream_level_1"])

        op1.set_downstream(op3)
        op2.set_downstream(op3)
        op3.set_downstream(op4)
        op4.set_downstream(op5)

    dag.clear()

    # execution_date is set in the context in order to avoid creating task instances
    ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
    ctx2 = {"ti": TI(task=op2, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
    ctx3 = {"ti": TI(task=op3, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
    ctx5 = {"ti": TI(task=op5, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}

    # prepare with manual inlets and outlets
    op1.pre_execute(ctx1)

    assert len(op1.inlets) == 1
    assert op1.inlets[0].url == f1s.format(DEFAULT_DATE)

    assert len(op1.outlets) == 1
    assert op1.outlets[0].url == f2s.format(DEFAULT_DATE)

    # post process with no backend
    op1.post_execute(ctx1)

    op2.pre_execute(ctx2)
    assert len(op2.inlets) == 0
    op2.post_execute(ctx2)

    op3.pre_execute(ctx3)
    assert len(op3.inlets) == 1
    assert op3.inlets[0].url == f2s.format(DEFAULT_DATE)
    assert op3.outlets[0] == file3
    op3.post_execute(ctx3)

    # skip 4

    op5.pre_execute(ctx5)
    assert len(op5.inlets) == 2
    op5.post_execute(ctx5)
class TestBranchDayOfWeekOperator(unittest.TestCase):
    """
    Tests for BranchDayOfWeekOperator
    """

    @classmethod
    def setUpClass(cls):
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            "branch_day_of_week_operator_test",
            start_date=DEFAULT_DATE,
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        self.branch_3 = None

    def tearDown(self):
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_task_ids_match_states(self, dr, task_ids_to_states):
        """Helper that asserts task instances with a given id are in a given state"""
        tis = dr.get_task_instances()
        for ti in tis:
            try:
                expected_state = task_ids_to_states[ti.task_id]
            except KeyError:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            else:
                self.assertEqual(
                    ti.state,
                    expected_state,
                    f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}",
                )

    @parameterized.expand([
        ("with-string", "Monday"),
        ("with-enum", WeekDay.MONDAY),
        ("with-enum-set", {WeekDay.MONDAY}),
        ("with-enum-set-2-items", {WeekDay.MONDAY, WeekDay.FRIDAY}),
        ("with-string-set", {"Monday"}),
        ("with-string-set-2-items", {"Monday", "Friday"}),
    ])
    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_true(self, _, weekday):
        """Checks if BranchDayOfWeekOperator follows true branch"""
        print(datetime.datetime.now())
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            week_day=weekday,
            dag=self.dag,
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.NONE,
                'branch_3': State.SKIPPED,
            },
        )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_true_with_execution_date(self):
        """Checks if BranchDayOfWeekOperator follows true branch when set use_task_execution_day"""
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Wednesday",
            use_task_execution_day=True,  # We compare to DEFAULT_DATE which is Wednesday
            dag=self.dag,
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.SKIPPED,
            },
        )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_false(self):
        """Checks if BranchDayOfWeekOperator follows false branch"""
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Sunday",
            dag=self.dag,
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.SKIPPED,
                'branch_2': State.NONE,
            },
        )

    def test_branch_with_no_weekday(self):
        """Check if BranchDayOfWeekOperator raises exception on missing weekday"""
        with self.assertRaises(AirflowException):
            BranchDayOfWeekOperator(  # pylint: disable=missing-kwoa
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                dag=self.dag,
            )

    def test_branch_with_invalid_type(self):
        """Check if BranchDayOfWeekOperator raises exception on unsupported weekday type"""
        invalid_week_day = ['Monday']
        with pytest.raises(
            TypeError,
            match='Unsupported Type for week_day parameter:'
            ' {}. It should be one of str, set or '
            'Weekday enum type'.format(type(invalid_week_day)),
        ):
            BranchDayOfWeekOperator(
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                week_day=invalid_week_day,
                dag=self.dag,
            )

    def test_weekday_branch_invalid_weekday_number(self):
        """Check if BranchDayOfWeekOperator raises exception on wrong value of weekday"""
        invalid_week_day = 'Thsday'
        with pytest.raises(AttributeError, match=f'Invalid Week Day passed: "{invalid_week_day}"'):
            BranchDayOfWeekOperator(
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                week_day=invalid_week_day,
                dag=self.dag,
            )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_xcom_push_true_branch(self):
        """Check if BranchDayOfWeekOperator pushes the value of follow_task_ids_if_true to XCom"""
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Monday",
            dag=self.dag,
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'
    'execution_timeout': timedelta(seconds=300),
    'working_dir': os.path.dirname(os.path.abspath(__file__)),
}

with DAG(
    'cpv_dwh_transformations',
    default_args=default_args,
    description='Refresh CPV data in the Data Warehouse layer (DWH)',
    schedule_interval=None,
    tags=['dend', 'cpv', 'dwh'],
) as dag:
    _docs_md_fp = os.path.join(default_args['working_dir'], 'Readme.md')
    dag.doc_md = open(_docs_md_fp, 'r').read()

    start_refresh = DummyOperator(task_id='start_refresh')

    create_cpv_dwh = RedshiftOperator(
        task_id='create_cpv_dwh',
        sql='schema_cpv_dwh.sql',
    )

    refresh_cpv_dwh = RedshiftOperator(
        task_id='refresh_cpv_dwh',
        sql='refresh_cpv.sql',
    )

    q_check = RedshiftQualityCheck(
        task_id='quality_check',
        schema="dwh",
        table="cpv_attributes",
def _fetch_dataset_old():
    print("Fetching data (OLD)...")


def _fetch_dataset_new():
    print("Fetching data (NEW)...")


with DAG(
    dag_id="03_branching",
    start_date=airflow.utils.dates.days_ago(3),
    schedule_interval="@daily",
) as dag:
    start = DummyOperator(task_id="start")

    pick_branch = BranchPythonOperator(task_id="pick_branch", python_callable=_pick_branch)
    fetch_dataset_old = PythonOperator(task_id="fetch_dataset_old", python_callable=_fetch_dataset_old)
    fetch_dataset_new = PythonOperator(task_id="fetch_dataset_new", python_callable=_fetch_dataset_new)

    fetch_another_dataset = DummyOperator(task_id="fetch_another_dataset")
    join_datasets = DummyOperator(task_id="join_datasets", trigger_rule="none_failed")
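# The _pick_branch callable referenced above is not shown in this snippet; a hypothetical
# version (the name, cut-over date and returned task_ids are assumptions) would simply
# return the task_id of the branch to follow:
import airflow.utils.dates

ERP_SWITCH_DATE = airflow.utils.dates.days_ago(1)  # assumed cut-over date


def _pick_branch(**context):
    if context["execution_date"] < ERP_SWITCH_DATE:
        return "fetch_dataset_old"
    return "fetch_dataset_new"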
from airflow.utils.dates import days_ago

DAG_NAME = 'example_subdag_operator'

args = {
    'owner': 'airflow',
}

dag = DAG(
    dag_id=DAG_NAME,
    default_args=args,
    start_date=days_ago(2),
    schedule_interval="@once",
    tags=['example'],
)

start = DummyOperator(
    task_id='start',
    dag=dag,
)

section_1 = SubDagOperator(
    task_id='section-1',
    subdag=subdag(DAG_NAME, 'section-1', args),
    dag=dag,
)

some_other_task = DummyOperator(
    task_id='some-other-task',
    dag=dag,
)

section_2 = SubDagOperator(
    task_id='section-2',
from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
from airflow.sensors.filesystem import FileSensor

from utils import DEFAULT_VOLUME, default_args

with DAG(
    "DAG3_inference",
    default_args=default_args,
    schedule_interval="@daily",
    start_date=days_ago(3),
) as dag:
    start_task = DummyOperator(task_id='start-prediction')

    data_await = FileSensor(
        filepath='/opt/airflow/data/raw/{{ ds }}/data.csv',
        task_id="await-data",
        poke_interval=10,
        retries=100,
    )

    model_await = FileSensor(
        filepath='/opt/airflow/{{ var.value.model_dir }}/model.pkl',
        task_id="await-model",
        poke_interval=10,
        retries=100,
    )

    preprocessing = DockerOperator(
        task_id="preprocessing",
    start_date=datetime(2021, 8, 13),
    schedule_interval="@daily",
    catchup=False,
    default_args={
        "retries": 1,
        "retry_delay": timedelta(minutes=3),
        "azure_data_factory_conn_id": "azure_data_factory",
        "factory_name": "my-data-factory",  # This can also be specified in the ADF connection.
        "resource_group_name": "my-resource-group",  # This can also be specified in the ADF connection.
    },
    default_view="graph",
) as dag:
    begin = DummyOperator(task_id="begin")
    end = DummyOperator(task_id="end")

    # [START howto_operator_adf_run_pipeline]
    run_pipeline1 = AzureDataFactoryRunPipelineOperator(
        task_id="run_pipeline1",
        pipeline_name="pipeline1",
        parameters={"myParam": "value"},
    )
    # [END howto_operator_adf_run_pipeline]

    # [START howto_operator_adf_run_pipeline_async]
    run_pipeline2 = AzureDataFactoryRunPipelineOperator(
        task_id="run_pipeline2",
        pipeline_name="pipeline2",
        wait_for_termination=False,
"email_on_retry": False, "retries": 1, "retry_delay": timedelta(minutes=5), } with DAG( dag_id="covid_data_dag", default_args=default_args, description= "DAG to update Covid 19 data daily to push to a Postgres database.", schedule_interval='30 9 * * *', start_date=datetime(2021, 8, 24), ) as dag: # Initiate tasks task_1 = DummyOperator(task_id="Initiate_DAG") task_2 = PythonOperator( task_id="dashboard_update", python_callable=covid_19_dashboard_update, op_kwargs={ "username": config.username, "password": passwords_dict.get('postgres_password'), "database": config.database, "table_name": config.table_name, "columns": config.columns, "geo_ids_url": config.geo_ids_url, }, ) task_1 >> task_2
def test_bad_trigger_rule(self):
    with self.assertRaises(AirflowException):
        DummyOperator(task_id='test_bad_trigger', trigger_rule="non_existent", dag=self.dag)
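def test_valid_trigger_rule_sketch(self):
    # Companion sketch (not part of the original suite): trigger_rule values are
    # validated against TriggerRule.all_triggers(), so any member of that set is
    # accepted where the "non_existent" string above is rejected.
    from airflow.utils.trigger_rule import TriggerRule

    assert "non_existent" not in TriggerRule.all_triggers()
    DummyOperator(task_id='test_valid_trigger', trigger_rule=TriggerRule.ALL_DONE, dag=self.dag)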
    'depends_on_past': False,
    "retries": 3,
    "retry_delay": timedelta(seconds=5),
}

DAYS = 1
START_DATE = datetime(2021, 3, 28, 0, 0, 0)

with DAG(
    dag_id=DAG_ID,
    default_args=default_args,
    description="test hourly dags",
    schedule_interval="@hourly",
    start_date=START_DATE,
    tags=[WORKFLOW_ID],
    user_defined_macros=default_args,
) as dag:
    start = DummyOperator(task_id="start")

    command = (
        "echo prev_ds :{{ prev_ds }}, "
        "ds :{{ ds }}, "
        "next_ds :{{ next_ds }}, "
        "next_execution_date : {{ next_execution_date }}"
    )

    t1 = BashOperator(task_id='t1', bash_command=command)
    t2 = BashOperator(task_id='t2', bash_command=command)
    t3 = BashOperator(task_id='t3', bash_command=command)
    end = DummyOperator(task_id="end")

    start >> t1 >> t2 >> t3 >> end
def test_build_task_group_with_prefix():
    """
    Tests that prefix_group_id turns on/off prefixing of task_id with group_id.
    """
    execution_date = pendulum.parse("20200101")
    with DAG("test_build_task_group_with_prefix", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group234", prefix_group_id=False) as group234:
            task2 = DummyOperator(task_id="task2")

            with TaskGroup("group34") as group34:
                task3 = DummyOperator(task_id="task3")

                with TaskGroup("group4", prefix_group_id=False) as group4:
                    task4 = DummyOperator(task_id="task4")

        task5 = DummyOperator(task_id="task5")
        task1 >> group234
        group34 >> task5

    assert task2.task_id == "task2"
    assert group34.group_id == "group34"
    assert task3.task_id == "group34.task3"
    assert group4.group_id == "group34.group4"
    assert task4.task_id == "task4"
    assert task5.task_id == "task5"
    assert group234.get_child_by_label("task2") == task2
    assert group234.get_child_by_label("group34") == group34
    assert group4.get_child_by_label("task4") == task4

    assert extract_node_id(task_group_to_dict(dag.task_group), include_label=True) == {
        'id': None,
        'label': None,
        'children': [
            {
                'id': 'group234',
                'label': 'group234',
                'children': [
                    {
                        'id': 'group34',
                        'label': 'group34',
                        'children': [
                            {
                                'id': 'group34.group4',
                                'label': 'group4',
                                'children': [{'id': 'task4', 'label': 'task4'}],
                            },
                            {'id': 'group34.task3', 'label': 'task3'},
                            {'id': 'group34.downstream_join_id', 'label': ''},
                        ],
                    },
                    {'id': 'task2', 'label': 'task2'},
                    {'id': 'group234.upstream_join_id', 'label': ''},
                ],
            },
            {'id': 'task1', 'label': 'task1'},
            {'id': 'task5', 'label': 'task5'},
        ],
    }
``none_failed_or_skipped`` trigger rule such that they are skipped whenever their corresponding
``BranchPythonOperator`` are skipped.
"""
from airflow.models import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import BranchPythonOperator
from airflow.utils.dates import days_ago

with DAG(
    dag_id="example_nested_branch_dag",
    start_date=days_ago(2),
    schedule_interval="@daily",
    tags=["example"],
) as dag:
    branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1")
    join_1 = DummyOperator(task_id="join_1", trigger_rule="none_failed_or_skipped")
    true_1 = DummyOperator(task_id="true_1")
    false_1 = DummyOperator(task_id="false_1")
    branch_2 = BranchPythonOperator(task_id="branch_2", python_callable=lambda: "true_2")
    join_2 = DummyOperator(task_id="join_2", trigger_rule="none_failed_or_skipped")
    true_2 = DummyOperator(task_id="true_2")
    false_2 = DummyOperator(task_id="false_2")
    false_3 = DummyOperator(task_id="false_3")

    branch_1 >> true_1 >> join_1
    branch_1 >> false_1 >> branch_2 >> [true_2, false_2] >> join_2 >> false_3 >> join_1
def test_sub_dag_task_group():
    """
    Tests dag.sub_dag() updates task_group correctly.
    """
    execution_date = pendulum.parse("20200101")
    with DAG("test_test_task_group_sub_dag", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group234") as group234:
            _ = DummyOperator(task_id="task2")

            with TaskGroup("group34") as group34:
                _ = DummyOperator(task_id="task3")
                _ = DummyOperator(task_id="task4")

        with TaskGroup("group6") as group6:
            _ = DummyOperator(task_id="task6")

        task7 = DummyOperator(task_id="task7")
        task5 = DummyOperator(task_id="task5")

        task1 >> group234
        group34 >> task5
        group234 >> group6
        group234 >> task7

    subdag = dag.sub_dag(task_ids_or_regex="task5", include_upstream=True, include_downstream=False)

    assert extract_node_id(task_group_to_dict(subdag.task_group)) == {
        'id': None,
        'children': [
            {
                'id': 'group234',
                'children': [
                    {
                        'id': 'group234.group34',
                        'children': [
                            {'id': 'group234.group34.task3'},
                            {'id': 'group234.group34.task4'},
                            {'id': 'group234.group34.downstream_join_id'},
                        ],
                    },
                    {'id': 'group234.upstream_join_id'},
                ],
            },
            {'id': 'task1'},
            {'id': 'task5'},
        ],
    }

    edges = dag_edges(subdag)
    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
        ('group234.group34.downstream_join_id', 'task5'),
        ('group234.group34.task3', 'group234.group34.downstream_join_id'),
        ('group234.group34.task4', 'group234.group34.downstream_join_id'),
        ('group234.upstream_join_id', 'group234.group34.task3'),
        ('group234.upstream_join_id', 'group234.group34.task4'),
        ('task1', 'group234.upstream_join_id'),
    ]

    subdag_task_groups = subdag.task_group.get_task_group_dict()
    assert subdag_task_groups.keys() == {None, "group234", "group234.group34"}

    included_group_ids = {"group234", "group234.group34"}
    included_task_ids = {'group234.group34.task3', 'group234.group34.task4', 'task1', 'task5'}

    for task_group in subdag_task_groups.values():
        assert task_group.upstream_group_ids.issubset(included_group_ids)
        assert task_group.downstream_group_ids.issubset(included_group_ids)
        assert task_group.upstream_task_ids.issubset(included_task_ids)
        assert task_group.downstream_task_ids.issubset(included_task_ids)

    for task in subdag.task_group:
        assert task.upstream_task_ids.issubset(included_task_ids)
        assert task.downstream_task_ids.issubset(included_task_ids)
import datetime as dt

from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.latest_only import LatestOnlyOperator
from airflow.utils.dates import days_ago
from airflow.utils.trigger_rule import TriggerRule

"""
Old runs are not executed; only the most recent one runs, so effectively there is no backfill.
task2 is not skipped, but task3 and task4 are: task1, task3 and task4 sit downstream of
latest_only and are all skipped, so even when task2 runs, task3 and task4 do not.
If task4 is given TriggerRule.ALL_DONE, it runs once all upstream tasks have finished,
even when some of them were skipped.
"""

dag = DAG(
    dag_id='latest_only_operator_test',
    schedule_interval=dt.timedelta(hours=4),
    start_date=days_ago(2),
    tags=['example3'],
)

latest_only = LatestOnlyOperator(task_id='latest_only', dag=dag)
task1 = DummyOperator(task_id='task1', dag=dag)
task2 = DummyOperator(task_id='task2', dag=dag)
task3 = DummyOperator(task_id='task3', dag=dag)
task4 = DummyOperator(task_id='task4', dag=dag)
# task4 = DummyOperator(task_id='task4', dag=dag, trigger_rule=TriggerRule.ALL_DONE)

latest_only >> task1 >> [task3, task4]
task2 >> [task3, task4]
def test_dag_edges():
    execution_date = pendulum.parse("20200101")
    with DAG("test_dag_edges", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group_a") as group_a:
            with TaskGroup("group_b") as group_b:
                task2 = DummyOperator(task_id="task2")
                task3 = DummyOperator(task_id="task3")
                task4 = DummyOperator(task_id="task4")
                task2 >> [task3, task4]

            task5 = DummyOperator(task_id="task5")

            task5 << group_b

        task1 >> group_a

        with TaskGroup("group_c") as group_c:
            task6 = DummyOperator(task_id="task6")
            task7 = DummyOperator(task_id="task7")
            task8 = DummyOperator(task_id="task8")
            [task6, task7] >> task8

        group_a >> group_c

        task5 >> task8

        task9 = DummyOperator(task_id="task9")
        task10 = DummyOperator(task_id="task10")

        group_c >> [task9, task10]

        with TaskGroup("group_d") as group_d:
            task11 = DummyOperator(task_id="task11")
            task12 = DummyOperator(task_id="task12")
            task11 >> task12

        group_d << group_c

    nodes = task_group_to_dict(dag.task_group)
    edges = dag_edges(dag)

    assert extract_node_id(nodes) == {
        'id': None,
        'children': [
            {
                'id': 'group_a',
                'children': [
                    {
                        'id': 'group_a.group_b',
                        'children': [
                            {'id': 'group_a.group_b.task2'},
                            {'id': 'group_a.group_b.task3'},
                            {'id': 'group_a.group_b.task4'},
                            {'id': 'group_a.group_b.downstream_join_id'},
                        ],
                    },
                    {'id': 'group_a.task5'},
                    {'id': 'group_a.upstream_join_id'},
                    {'id': 'group_a.downstream_join_id'},
                ],
            },
            {
                'id': 'group_c',
                'children': [
                    {'id': 'group_c.task6'},
                    {'id': 'group_c.task7'},
                    {'id': 'group_c.task8'},
                    {'id': 'group_c.upstream_join_id'},
                    {'id': 'group_c.downstream_join_id'},
                ],
            },
            {
                'id': 'group_d',
                'children': [
                    {'id': 'group_d.task11'},
                    {'id': 'group_d.task12'},
                    {'id': 'group_d.upstream_join_id'},
                ],
            },
            {'id': 'task1'},
            {'id': 'task10'},
            {'id': 'task9'},
        ],
    }

    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
        ('group_a.downstream_join_id', 'group_c.upstream_join_id'),
        ('group_a.group_b.downstream_join_id', 'group_a.task5'),
        ('group_a.group_b.task2', 'group_a.group_b.task3'),
        ('group_a.group_b.task2', 'group_a.group_b.task4'),
        ('group_a.group_b.task3', 'group_a.group_b.downstream_join_id'),
        ('group_a.group_b.task4', 'group_a.group_b.downstream_join_id'),
        ('group_a.task5', 'group_a.downstream_join_id'),
        ('group_a.task5', 'group_c.task8'),
        ('group_a.upstream_join_id', 'group_a.group_b.task2'),
        ('group_c.downstream_join_id', 'group_d.upstream_join_id'),
        ('group_c.downstream_join_id', 'task10'),
        ('group_c.downstream_join_id', 'task9'),
        ('group_c.task6', 'group_c.task8'),
        ('group_c.task7', 'group_c.task8'),
        ('group_c.task8', 'group_c.downstream_join_id'),
        ('group_c.upstream_join_id', 'group_c.task6'),
        ('group_c.upstream_join_id', 'group_c.task7'),
        ('group_d.task11', 'group_d.task12'),
        ('group_d.upstream_join_id', 'group_d.task11'),
        ('task1', 'group_a.upstream_join_id'),
    ]
from airflow.utils.dates import days_ago

args = {
    "owner": "airflow",
}

dag = DAG(
    dag_id="example_branch_datetime_operator",
    start_date=days_ago(2),
    default_args=args,
    tags=["example"],
    schedule_interval="@daily",
)

# [START howto_branch_datetime_operator]
dummy_task_1 = DummyOperator(task_id='date_in_range', dag=dag)
dummy_task_2 = DummyOperator(task_id='date_outside_range', dag=dag)

cond1 = BranchDateTimeOperator(
    task_id='datetime_branch',
    follow_task_ids_if_true=['date_in_range'],
    follow_task_ids_if_false=['date_outside_range'],
    target_upper=datetime.datetime(2020, 10, 10, 15, 0, 0),
    target_lower=datetime.datetime(2020, 10, 10, 14, 0, 0),
    dag=dag,
)

# Run dummy_task_1 if cond1 executes between 2020-10-10 14:00:00 and 2020-10-10 15:00:00
cond1 >> [dummy_task_1, dummy_task_2]
# [END howto_branch_datetime_operator]
    return 'in_accurate'


with DAG('xcom_dag', schedule_interval='@daily', default_args=default_args, catchup=False) as dag:

    downloading_data = BashOperator(
        task_id='downloading_data',
        bash_command='sleep 3',
        do_xcom_push=False,
    )

    with TaskGroup('processing_tasks') as processing_tasks:
        training_model_a = PythonOperator(task_id='training_model_a', python_callable=_training_model)
        training_model_b = PythonOperator(task_id='training_model_b', python_callable=_training_model)
        training_model_c = PythonOperator(task_id='training_model_c', python_callable=_training_model)

    choose_model = BranchPythonOperator(task_id='task_4', python_callable=_choose_best_model)

    accurate = DummyOperator(task_id='accurate')
    in_accurate = DummyOperator(task_id='in_accurate')

    downloading_data >> processing_tasks >> choose_model
    choose_model >> [accurate, in_accurate]
def _latest_only(**context):
    now = pendulum.now("UTC")
    left_window = context["dag"].following_schedule(context["execution_date"])
    right_window = context["dag"].following_schedule(left_window)

    if not left_window < now <= right_window:
        raise AirflowSkipException()


with DAG(
    dag_id="06_condition_dag",
    start_date=airflow.utils.dates.days_ago(3),
    schedule_interval="@daily",
) as dag:
    start = DummyOperator(task_id="start")

    pick_erp = BranchPythonOperator(task_id="pick_erp_system", python_callable=_pick_erp_system)

    fetch_sales_old = DummyOperator(task_id="fetch_sales_old")
    clean_sales_old = DummyOperator(task_id="clean_sales_old")

    fetch_sales_new = DummyOperator(task_id="fetch_sales_new")
    clean_sales_new = DummyOperator(task_id="clean_sales_new")

    join_erp = DummyOperator(task_id="join_erp_branch", trigger_rule="none_failed")

    fetch_weather = DummyOperator(task_id="fetch_weather")
    clean_weather = DummyOperator(task_id="clean_weather")
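    # Hypothetical wiring sketch (not in the original file; the choice of downstream tasks
    # is an assumption): the manual latest-only check above would sit in a PythonOperator
    # placed in front of whatever should only run for the most recent schedule interval.
    from airflow.operators.python import PythonOperator

    latest_only = PythonOperator(task_id="latest_only", python_callable=_latest_only)
    latest_only >> [clean_sales_old, clean_sales_new]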
class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
    """
    Test for SQL Branch Operator
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        super().setUp()
        self.dag = DAG(
            "sql_branch_operator_test",
            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        self.branch_3 = None

    def tearDown(self):
        super().tearDown()
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_unsupported_conn_type(self):
        """Check if BranchSQLOperator throws an exception for unsupported connection type"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="redis_default",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_invalid_conn(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_invalid_follow_task_true(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true=None,
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_invalid_follow_task_false(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false=None,
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    @pytest.mark.backend("mysql")
    def test_sql_branch_operator_mysql(self):
        """Check if BranchSQLOperator works with backend"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    @pytest.mark.backend("postgres")
    def test_sql_branch_operator_postgres(self):
        """Check if BranchSQLOperator works with backend"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="postgres_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first
        mock_get_records.return_value = 1

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                assert ti.state == State.SUCCESS
            elif ti.task_id == "branch_1":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_2":
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_true_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for true_value in SUPPORTED_TRUE_VALUES:
            mock_get_records.return_value = true_value

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.NONE
                elif ti.task_id == "branch_2":
                    assert ti.state == State.SKIPPED
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_false_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_records.return_value = false_value

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.SKIPPED
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_list_with_dag_run(self, mock_get_db_hook):
        """Checks if the BranchSQLOperator supports branching off to a list of tasks."""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first
        mock_get_records.return_value = [["1"]]

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                assert ti.state == State.SUCCESS
            elif ti.task_id == "branch_1":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_2":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_3":
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first
        mock_get_records.return_value = ["Invalid Value"]

        with pytest.raises(AirflowException):
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook):
        """Test SQL Branch with skipping all downstream dependencies"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for true_value in SUPPORTED_TRUE_VALUES:
            mock_get_records.return_value = [true_value]

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.NONE
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook):
        """Test skipping downstream dependency for false condition"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_records.return_value = [false_value]

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.SKIPPED
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")
def _wait_for_supermarket(supermarket_id_):
    supermarket_path = Path("/data/" + supermarket_id_)
    data_files = supermarket_path.glob("data-*.csv")
    success_file = supermarket_path / "_SUCCESS"
    return data_files and success_file.exists()


for supermarket_id in range(1, 5):
    wait = PythonSensor(
        task_id=f"wait_for_supermarket_{supermarket_id}",
        python_callable=_wait_for_supermarket,
        op_kwargs={"supermarket_id_": f"supermarket{supermarket_id}"},
        dag=dag1,
    )
    copy = DummyOperator(task_id=f"copy_to_raw_supermarket_{supermarket_id}", dag=dag1)
    process = DummyOperator(task_id=f"process_supermarket_{supermarket_id}", dag=dag1)
    trigger_create_metrics_dag = TriggerDagRunOperator(
        task_id=f"trigger_create_metrics_dag_supermarket_{supermarket_id}",
        trigger_dag_id="listing_6_04_dag02",
        dag=dag1,
    )
    wait >> copy >> process >> trigger_create_metrics_dag

compute_differences = DummyOperator(task_id="compute_differences", dag=dag2)
update_dashboard = DummyOperator(task_id="update_dashboard", dag=dag2)
notify_new_data = DummyOperator(task_id="notify_new_data", dag=dag2)
compute_differences >> update_dashboard
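# Side note (a sketch, not part of the original listing): Path.glob() returns a generator,
# which is always truthy, so the sensor condition above effectively only checks the
# _SUCCESS marker. A stricter variant would materialize the matches first:
from pathlib import Path


def _wait_for_supermarket_strict(supermarket_id_):
    supermarket_path = Path("/data/" + supermarket_id_)
    data_files = list(supermarket_path.glob("data-*.csv"))  # force evaluation of the glob
    success_file = supermarket_path / "_SUCCESS"
    return bool(data_files) and success_file.exists()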
def subdag_b():
    subdag_b = DAG('nested_cycle.op_subdag_0.opSubdag_B', default_args=default_args)
    DummyOperator(task_id='subdag_b.task', dag=subdag_b)
    return subdag_b