Example #1
def subdag_d():
    subdag_d = DAG('nested_cycle.op_subdag_1.opSubdag_D',
                   default_args=default_args)
    DummyOperator(task_id='subdag_d.task', dag=subdag_d)
    return subdag_d
Example #2
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example of the LatestOnlyOperator"""

import datetime as dt

from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.latest_only import LatestOnlyOperator
from airflow.utils.dates import days_ago

with DAG(
        dag_id='latest_only',
        schedule_interval=dt.timedelta(hours=4),
        start_date=days_ago(2),
        tags=['example2', 'example3'],
) as dag:

    latest_only = LatestOnlyOperator(task_id='latest_only')
    task1 = DummyOperator(task_id='task1')

    latest_only >> task1
Example #3

from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.utils.dates import days_ago
from airflow.sensors.filesystem import FileSensor

from utils import default_args, VOLUME

with DAG(
        "2_train_pipeline",
        default_args=default_args,
        schedule_interval="@weekly",
        start_date=days_ago(5),
) as dag:

    start = DummyOperator(task_id="Begin")

    data_sensor = FileSensor(task_id="Wait_for_data",
                             poke_interval=10,
                             retries=100,
                             filepath="data/raw/{{ ds }}/data.csv")

    target_sensor = FileSensor(task_id="Wait_for_target",
                               poke_interval=10,
                               retries=100,
                               filepath="data/raw/{{ ds }}/target.csv")

    preprocess = DockerOperator(
        task_id="Data_preprocess",
        image="airflow-preprocess",
        command=
Example #4
def bs_disaster_dag():
    @task()
    def extract_transform():
        df = pd.read_csv(f"{DATA_PATH}/disaster_data.csv")
        columns = ['text', 'location']
        for column in columns:
            df[column] = df[column].str.replace(r'\s{2,}', ' ', regex=True)
            df[column] = df[column].str.replace(r"[^a-zA-Z0-9\,]",
                                                ' ',
                                                regex=True)

        df.to_csv(OUT_PATH, index=False, header=False)

    start = DummyOperator(task_id='start')
    end = DummyOperator(task_id='end')
    extract_transform_task = extract_transform()

    stored_data_gcs = LocalFilesystemToGCSOperator(
        task_id="store_to_gcs",
        gcp_conn_id=GOOGLE_CLOUD_CONN_ID,
        src=OUT_PATH,
        dst=GCS_OBJECT_NAME,
        bucket=BUCKET_NAME)

    loaded_data_bigquery = GCSToBigQueryOperator(
        task_id='load_to_bigquery',
        bigquery_conn_id=GOOGLE_CLOUD_CONN_ID,
        bucket=BUCKET_NAME,
        source_objects=[GCS_OBJECT_NAME],
        destination_project_dataset_table=f"{DATASET_ID}.{BIGQUERY_TABLE_NAME}",
        schema_fields=[  #based on https://cloud.google.com/bigquery/docs/schemas
            {
                'name': 'id',
                'type': 'INT64',
                'mode': 'REQUIRED'
            },
            {
                'name': 'keyword',
                'type': 'STRING',
                'mode': 'NULLABLE'
            },
            {
                'name': 'location',
                'type': 'STRING',
                'mode': 'NULLABLE'
            },
            {
                'name': 'text',
                'type': 'STRING',
                'mode': 'NULLABLE'
            },
            {
                'name': 'target',
                'type': 'INT64',
                'mode': 'NULLABLE'
            },
        ],
        autodetect=False,
        write_disposition='WRITE_TRUNCATE',  # If the table already exists, overwrites the table data
    )

    start >> extract_transform_task
    extract_transform_task >> stored_data_gcs
    stored_data_gcs >> loaded_data_bigquery
    loaded_data_bigquery >> end
Example #5
    's3_bucket': Variable.get('s3_bucket'),
    'arn': Variable.get('arn'),
    'fn': Variable.get('cpv_csvname'),
    'working_dir': os.path.dirname(os.path.abspath(__file__))
}

with DAG('cpv_from_local_to_redshift',
         default_args=default_args,
         description='Upload the CPV File from local system to S3 and to Redshift',
         schedule_interval=None,
         tags=['dend', 'cpv', 'staging']) as dag:
    _docs_md_fp = os.path.join(default_args['working_dir'], 'Readme.md')
    dag.doc_md = open(_docs_md_fp, 'r').read()

    start_cpv = DummyOperator(task_id='start_cpv')

    stop_cpv = DummyOperator(task_id='stop_cpv')

    upload_cpv_to_s3 = S3UploadFromLocal(task_id='Upload_cpv_to_s3s',
                                         s3_folder='staging/cpv_attributes/')

    create_redshift = RedshiftOperator(task_id='create_redshift',
                                       sql='schema_cpv_staging_datalake.sql')

    copy_from_s3 = RedshiftCopyFromS3(task_id='copy_from_s3',
                                      s3_folder='staging/cpv_attributes',
                                      schema='staging',
                                      table='cpv_attributes',
                                      format='csv',
                                      header=True,
Example #6

    'email_on_retry': False,
    'start_date': START_DATE,
    'retries': 1,
    'retry_delay': timedelta(minutes=1)
}

dag = DAG(DAG_ID,
          default_args=default_args,
          schedule_interval=SCHEDULE_INTERVAL,
          start_date=START_DATE)
if hasattr(dag, 'doc_md'):
    dag.doc_md = __doc__
if hasattr(dag, 'catchup'):
    dag.catchup = False

start = DummyOperator(task_id='start', dag=dag)

log_cleanup = """

echo "Getting Configurations..."
BASE_LOG_FOLDER="{{params.directory}}"
WORKER_SLEEP_TIME="{{params.sleep_time}}"

sleep ${WORKER_SLEEP_TIME}s

MAX_LOG_AGE_IN_DAYS='""" + str(DEFAULT_MAX_LOG_AGE_IN_DAYS) + """'

ENABLE_DELETE=""" + str("true" if ENABLE_DELETE else "false") + """
echo "Finished Getting Configurations"
echo ""
Example #7

args = {
    'owner': 'airflow',
}

with DAG(
        dag_id='example_jdbc_operator',
        default_args=args,
        schedule_interval='0 0 * * *',
        start_date=days_ago(2),
        dagrun_timeout=timedelta(minutes=60),
        tags=['example'],
) as dag:

    run_this_last = DummyOperator(
        task_id='run_this_last',
        dag=dag,
    )

    # [START howto_operator_jdbc_template]
    delete_data = JdbcOperator(
        task_id='delete_data',
        sql='delete from my_schema.my_table where dt = {{ ds }}',
        jdbc_conn_id='my_jdbc_connection',
        autocommit=True,
        dag=dag,
    )
    # [END howto_operator_jdbc_template]

    # [START howto_operator_jdbc]
    insert_data = JdbcOperator(
        task_id='insert_data',
Example #8
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.dummy import DummyOperator

default_args = {
    'owner': 'Airflow',
    'start_date': datetime(2021, 3, 22),
    'retries': 1,
    'retry_delay': timedelta(seconds=30)
}

with DAG(dag_id='cycle_error',
         default_args=default_args,
         schedule_interval=None) as dag:
    t1 = DummyOperator(task_id='t1')
    t2 = DummyOperator(task_id='t2')
    t3 = DummyOperator(task_id='t3')

    t1 >> t2 >> t3 >> t1
Example #9
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


from datetime import datetime

from airflow.models import DAG
from airflow.operators.dummy import DummyOperator

for i in range(1, 2):
    dag = DAG(dag_id=f'test_latest_runs_{i}')
    task = DummyOperator(task_id='dummy_task', dag=dag, owner='airflow', start_date=datetime(2016, 2, 1))
Example #10
    def test_lineage(self):
        dag = DAG(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE)

        f1s = "/tmp/does_not_exist_1-{}"
        f2s = "/tmp/does_not_exist_2-{}"
        f3s = "/tmp/does_not_exist_3"
        file1 = File(f1s.format("{{ execution_date }}"))
        file2 = File(f2s.format("{{ execution_date }}"))
        file3 = File(f3s)

        with dag:
            op1 = DummyOperator(
                task_id='leave1',
                inlets=file1,
                outlets=[
                    file2,
                ],
            )
            op2 = DummyOperator(task_id='leave2')
            op3 = DummyOperator(task_id='upstream_level_1',
                                inlets=AUTO,
                                outlets=file3)
            op4 = DummyOperator(task_id='upstream_level_2')
            op5 = DummyOperator(task_id='upstream_level_3',
                                inlets=["leave1", "upstream_level_1"])

            op1.set_downstream(op3)
            op2.set_downstream(op3)
            op3.set_downstream(op4)
            op4.set_downstream(op5)

        dag.clear()

        # execution_date is set in the context in order to avoid creating task instances
        ctx1 = {
            "ti": TI(task=op1, execution_date=DEFAULT_DATE),
            "execution_date": DEFAULT_DATE
        }
        ctx2 = {
            "ti": TI(task=op2, execution_date=DEFAULT_DATE),
            "execution_date": DEFAULT_DATE
        }
        ctx3 = {
            "ti": TI(task=op3, execution_date=DEFAULT_DATE),
            "execution_date": DEFAULT_DATE
        }
        ctx5 = {
            "ti": TI(task=op5, execution_date=DEFAULT_DATE),
            "execution_date": DEFAULT_DATE
        }

        # prepare with manual inlets and outlets
        op1.pre_execute(ctx1)

        assert len(op1.inlets) == 1
        assert op1.inlets[0].url == f1s.format(DEFAULT_DATE)

        assert len(op1.outlets) == 1
        assert op1.outlets[0].url == f2s.format(DEFAULT_DATE)

        # post process with no backend
        op1.post_execute(ctx1)

        op2.pre_execute(ctx2)
        assert len(op2.inlets) == 0
        op2.post_execute(ctx2)

        op3.pre_execute(ctx3)
        assert len(op3.inlets) == 1
        assert op3.inlets[0].url == f2s.format(DEFAULT_DATE)
        assert op3.outlets[0] == file3
        op3.post_execute(ctx3)

        # skip 4

        op5.pre_execute(ctx5)
        assert len(op5.inlets) == 2
        op5.post_execute(ctx5)
Example #11
class TestBranchDayOfWeekOperator(unittest.TestCase):
    """
    Tests for BranchDayOfWeekOperator
    """
    @classmethod
    def setUpClass(cls):

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            "branch_day_of_week_operator_test",
            start_date=DEFAULT_DATE,
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        self.branch_3 = None

    def tearDown(self):

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_task_ids_match_states(self, dr, task_ids_to_states):
        """Helper that asserts task instances with a given id are in a given state"""
        tis = dr.get_task_instances()
        for ti in tis:
            try:
                expected_state = task_ids_to_states[ti.task_id]
            except KeyError:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            else:
                self.assertEqual(
                    ti.state,
                    expected_state,
                    f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}",
                )

    @parameterized.expand([
        ("with-string", "Monday"),
        ("with-enum", WeekDay.MONDAY),
        ("with-enum-set", {WeekDay.MONDAY}),
        ("with-enum-set-2-items", {WeekDay.MONDAY, WeekDay.FRIDAY}),
        ("with-string-set", {"Monday"}),
        ("with-string-set-2-items", {"Monday", "Friday"}),
    ])
    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_true(self, _, weekday):
        """Checks if BranchDayOfWeekOperator follows true branch"""
        print(datetime.datetime.now())
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            week_day=weekday,
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.NONE,
                'branch_3': State.SKIPPED,
            },
        )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_true_with_execution_date(self):
        """Checks if BranchDayOfWeekOperator follows true branch when set use_task_execution_day """

        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Wednesday",
            use_task_execution_day=True,  # We compare to DEFAULT_DATE which is a Wednesday
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.SKIPPED,
            },
        )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_false(self):
        """Checks if BranchDayOfWeekOperator follow false branch"""

        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Sunday",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.SKIPPED,
                'branch_2': State.NONE,
            },
        )

    def test_branch_with_no_weekday(self):
        """Check if BranchDayOfWeekOperator raises exception on missing weekday"""
        with self.assertRaises(AirflowException):
            BranchDayOfWeekOperator(  # pylint: disable=missing-kwoa
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                dag=self.dag,
            )

    def test_branch_with_invalid_type(self):
        """Check if BranchDayOfWeekOperator raises exception on unsupported weekday type"""
        invalid_week_day = ['Monday']
        with pytest.raises(
                TypeError,
                match='Unsupported Type for week_day parameter:'
                ' {}. It should be one of str, set or '
                'Weekday enum type'.format(type(invalid_week_day)),
        ):
            BranchDayOfWeekOperator(
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                week_day=invalid_week_day,
                dag=self.dag,
            )

    def test_weekday_branch_invalid_weekday_number(self):
        """Check if BranchDayOfWeekOperator raises exception on wrong value of weekday"""
        invalid_week_day = 'Thsday'
        with pytest.raises(
                AttributeError,
                match=f'Invalid Week Day passed: "{invalid_week_day}"'):
            BranchDayOfWeekOperator(
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                week_day=invalid_week_day,
                dag=self.dag,
            )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_xcom_push_true_branch(self):
        """Check if BranchDayOfWeekOperator push to xcom value of follow_task_ids_if_true"""
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Monday",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'
Example #12
    'execution_timeout': timedelta(seconds=300),
    'working_dir': os.path.dirname(os.path.abspath(__file__))
}

with DAG(
    'cpv_dwh_transformations',
    default_args=default_args,
    description='Refresh CPV data in the Data Warehouse layer (DWH)',
    schedule_interval=None,
    tags=['dend', 'cpv', 'dwh']
) as dag:
    _docs_md_fp = os.path.join(default_args['working_dir'], 'Readme.md')
    dag.doc_md = open(_docs_md_fp, 'r').read()

    start_refresh = DummyOperator(
        task_id='start_refresh'
    )

    create_cpv_dwh = RedshiftOperator(
        task_id='create_cpv_dwh',
        sql='schema_cpv_dwh.sql'
    )
    refresh_cpv_dwh = RedshiftOperator(
        task_id='refresh_cpv_dwh',
        sql='refresh_cpv.sql'
    )

    q_check = RedshiftQualityCheck(
        task_id='quality_check',
        schema="dwh",
        table="cpv_attributes",
Example #13

def _fetch_dataset_old():
    print("Fetching data (OLD)...")


def _fetch_dataset_new():
    print("Fetching data (NEW)...")


with DAG(
        dag_id="03_branching",
        start_date=airflow.utils.dates.days_ago(3),
        schedule_interval="@daily",
) as dag:
    start = DummyOperator(task_id="start")

    pick_branch = BranchPythonOperator(task_id="pick_branch",
                                       python_callable=_pick_branch)

    fetch_dataset_old = PythonOperator(task_id="fetch_dataset_old",
                                       python_callable=_fetch_dataset_old)

    fetch_dataset_new = PythonOperator(task_id="fetch_dataset_new",
                                       python_callable=_fetch_dataset_new)

    fetch_another_dataset = DummyOperator(task_id="fetch_another_dataset")

    join_datasets = DummyOperator(task_id="join_datasets",
                                  trigger_rule="none_failed")
Example #14

from airflow.utils.dates import days_ago

DAG_NAME = 'example_subdag_operator'

args = {
    'owner': 'airflow',
}

dag = DAG(dag_id=DAG_NAME,
          default_args=args,
          start_date=days_ago(2),
          schedule_interval="@once",
          tags=['example'])

start = DummyOperator(
    task_id='start',
    dag=dag,
)

section_1 = SubDagOperator(
    task_id='section-1',
    subdag=subdag(DAG_NAME, 'section-1', args),
    dag=dag,
)

some_other_task = DummyOperator(
    task_id='some-other-task',
    dag=dag,
)

section_2 = SubDagOperator(
    task_id='section-2',
Example #15
from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
from airflow.sensors.filesystem import FileSensor

from utils import DEFAULT_VOLUME, default_args

with DAG("DAG3_inference",
         default_args=default_args,
         schedule_interval="@daily",
         start_date=days_ago(3)) as dag:

    start_task = DummyOperator(task_id='start-prediction')

    data_await = FileSensor(
        filepath='/opt/airflow/data/raw/{{ ds }}/data.csv',
        task_id="await-data",
        poke_interval=10,
        retries=100,
    )
    model_await = FileSensor(
        filepath='/opt/airflow/{{ var.value.model_dir }}/model.pkl',
        task_id="await-model",
        poke_interval=10,
        retries=100,
    )
    preprocessing = DockerOperator(
        task_id="preprocessing",
Example #16
        start_date=datetime(2021, 8, 13),
        schedule_interval="@daily",
        catchup=False,
        default_args={
            "retries": 1,
            "retry_delay": timedelta(minutes=3),
            "azure_data_factory_conn_id": "azure_data_factory",
            "factory_name": "my-data-factory",  # This can also be specified in the ADF connection.
            "resource_group_name": "my-resource-group",  # This can also be specified in the ADF connection.
        },
        default_view="graph",
) as dag:
    begin = DummyOperator(task_id="begin")
    end = DummyOperator(task_id="end")

    # [START howto_operator_adf_run_pipeline]
    run_pipeline1 = AzureDataFactoryRunPipelineOperator(
        task_id="run_pipeline1",
        pipeline_name="pipeline1",
        parameters={"myParam": "value"},
    )
    # [END howto_operator_adf_run_pipeline]

    # [START howto_operator_adf_run_pipeline_async]
    run_pipeline2 = AzureDataFactoryRunPipelineOperator(
        task_id="run_pipeline2",
        pipeline_name="pipeline2",
        wait_for_termination=False,
Example #17
    "email_on_retry": False,
    "retries": 1,
    "retry_delay": timedelta(minutes=5),
}

with DAG(
        dag_id="covid_data_dag",
        default_args=default_args,
        description="DAG to update Covid 19 data daily to push to a Postgres database.",
        schedule_interval='30 9 * * *',
        start_date=datetime(2021, 8, 24),
) as dag:

    # Initiate tasks
    task_1 = DummyOperator(task_id="Initiate_DAG")

    task_2 = PythonOperator(
        task_id="dashboard_update",
        python_callable=covid_19_dashboard_update,
        op_kwargs={
            "username": config.username,
            "password": passwords_dict.get('postgres_password'),
            "database": config.database,
            "table_name": config.table_name,
            "columns": config.columns,
            "geo_ids_url": config.geo_ids_url,
        },
    )

    task_1 >> task_2
Example #18
def test_bad_trigger_rule(self):
    with self.assertRaises(AirflowException):
        DummyOperator(task_id='test_bad_trigger',
                      trigger_rule="non_existent",
                      dag=self.dag)
Example #19
    'depends_on_past': False,
    "retries": 3,
    "retry_delay": timedelta(seconds=5)
}

DAYS = 1
START_DATE = datetime(2021, 3, 28, 0, 0, 0)

with DAG(dag_id=DAG_ID,
         default_args=default_args,
         description="test hourly dags",
         schedule_interval="@hourly",
         start_date=START_DATE,
         tags=[WORKFLOW_ID],
         user_defined_macros=default_args) as dag:
    start = DummyOperator(task_id="start")
    command = ("echo prev_ds :{{ prev_ds }}, "
               "ds :{{ ds }}, "
               "next_ds :{{ next_ds }}, "
               "next_execution_date : {{ next_execution_date }}")

    t1 = BashOperator(task_id='t1', bash_command=command)

    t2 = BashOperator(task_id='t2', bash_command=command)

    t3 = BashOperator(task_id='t3', bash_command=command)

    end = DummyOperator(task_id="end")

    start >> t1 >> t2 >> t3 >> end
Example #20
def test_build_task_group_with_prefix():
    """
    Tests that prefix_group_id turns on/off prefixing of task_id with group_id.
    """
    execution_date = pendulum.parse("20200101")
    with DAG("test_build_task_group_with_prefix",
             start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group234", prefix_group_id=False) as group234:
            task2 = DummyOperator(task_id="task2")

            with TaskGroup("group34") as group34:
                task3 = DummyOperator(task_id="task3")

                with TaskGroup("group4", prefix_group_id=False) as group4:
                    task4 = DummyOperator(task_id="task4")

        task5 = DummyOperator(task_id="task5")
        task1 >> group234
        group34 >> task5

    assert task2.task_id == "task2"
    assert group34.group_id == "group34"
    assert task3.task_id == "group34.task3"
    assert group4.group_id == "group34.group4"
    assert task4.task_id == "task4"
    assert task5.task_id == "task5"
    assert group234.get_child_by_label("task2") == task2
    assert group234.get_child_by_label("group34") == group34
    assert group4.get_child_by_label("task4") == task4

    assert extract_node_id(
        task_group_to_dict(dag.task_group), include_label=True) == {
            'id':
            None,
            'label':
            None,
            'children': [
                {
                    'id':
                    'group234',
                    'label':
                    'group234',
                    'children': [
                        {
                            'id':
                            'group34',
                            'label':
                            'group34',
                            'children': [
                                {
                                    'id':
                                    'group34.group4',
                                    'label':
                                    'group4',
                                    'children': [{
                                        'id': 'task4',
                                        'label': 'task4'
                                    }],
                                },
                                {
                                    'id': 'group34.task3',
                                    'label': 'task3'
                                },
                                {
                                    'id': 'group34.downstream_join_id',
                                    'label': ''
                                },
                            ],
                        },
                        {
                            'id': 'task2',
                            'label': 'task2'
                        },
                        {
                            'id': 'group234.upstream_join_id',
                            'label': ''
                        },
                    ],
                },
                {
                    'id': 'task1',
                    'label': 'task1'
                },
                {
                    'id': 'task5',
                    'label': 'task5'
                },
            ],
        }
Example #21
``none_failed_or_skipped`` trigger rule such that they are skipped whenever their corresponding
``BranchPythonOperator`` are skipped.
"""

from airflow.models import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import BranchPythonOperator
from airflow.utils.dates import days_ago

with DAG(dag_id="example_nested_branch_dag",
         start_date=days_ago(2),
         schedule_interval="@daily",
         tags=["example"]) as dag:
    branch_1 = BranchPythonOperator(task_id="branch_1",
                                    python_callable=lambda: "true_1")
    join_1 = DummyOperator(task_id="join_1",
                           trigger_rule="none_failed_or_skipped")
    true_1 = DummyOperator(task_id="true_1")
    false_1 = DummyOperator(task_id="false_1")
    branch_2 = BranchPythonOperator(task_id="branch_2",
                                    python_callable=lambda: "true_2")
    join_2 = DummyOperator(task_id="join_2",
                           trigger_rule="none_failed_or_skipped")
    true_2 = DummyOperator(task_id="true_2")
    false_2 = DummyOperator(task_id="false_2")
    false_3 = DummyOperator(task_id="false_3")

    branch_1 >> true_1 >> join_1
    branch_1 >> false_1 >> branch_2 >> [true_2, false_2] >> join_2 >> false_3 >> join_1
Example #22
def test_sub_dag_task_group():
    """
    Tests dag.sub_dag() updates task_group correctly.
    """
    execution_date = pendulum.parse("20200101")
    with DAG("test_test_task_group_sub_dag", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group234") as group234:
            _ = DummyOperator(task_id="task2")

            with TaskGroup("group34") as group34:
                _ = DummyOperator(task_id="task3")
                _ = DummyOperator(task_id="task4")

        with TaskGroup("group6") as group6:
            _ = DummyOperator(task_id="task6")

        task7 = DummyOperator(task_id="task7")
        task5 = DummyOperator(task_id="task5")

        task1 >> group234
        group34 >> task5
        group234 >> group6
        group234 >> task7

    subdag = dag.sub_dag(task_ids_or_regex="task5",
                         include_upstream=True,
                         include_downstream=False)

    assert extract_node_id(task_group_to_dict(subdag.task_group)) == {
        'id':
        None,
        'children': [
            {
                'id':
                'group234',
                'children': [
                    {
                        'id':
                        'group234.group34',
                        'children': [
                            {
                                'id': 'group234.group34.task3'
                            },
                            {
                                'id': 'group234.group34.task4'
                            },
                            {
                                'id': 'group234.group34.downstream_join_id'
                            },
                        ],
                    },
                    {
                        'id': 'group234.upstream_join_id'
                    },
                ],
            },
            {
                'id': 'task1'
            },
            {
                'id': 'task5'
            },
        ],
    }

    edges = dag_edges(subdag)
    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
        ('group234.group34.downstream_join_id', 'task5'),
        ('group234.group34.task3', 'group234.group34.downstream_join_id'),
        ('group234.group34.task4', 'group234.group34.downstream_join_id'),
        ('group234.upstream_join_id', 'group234.group34.task3'),
        ('group234.upstream_join_id', 'group234.group34.task4'),
        ('task1', 'group234.upstream_join_id'),
    ]

    subdag_task_groups = subdag.task_group.get_task_group_dict()
    assert subdag_task_groups.keys() == {None, "group234", "group234.group34"}

    included_group_ids = {"group234", "group234.group34"}
    included_task_ids = {
        'group234.group34.task3', 'group234.group34.task4', 'task1', 'task5'
    }

    for task_group in subdag_task_groups.values():
        assert task_group.upstream_group_ids.issubset(included_group_ids)
        assert task_group.downstream_group_ids.issubset(included_group_ids)
        assert task_group.upstream_task_ids.issubset(included_task_ids)
        assert task_group.downstream_task_ids.issubset(included_task_ids)

    for task in subdag.task_group:
        assert task.upstream_task_ids.issubset(included_task_ids)
        assert task.downstream_task_ids.issubset(included_task_ids)
Example #23

import datetime as dt

from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.latest_only import LatestOnlyOperator
from airflow.utils.dates import days_ago
from airflow.utils.trigger_rule import TriggerRule

"""
Only the most recent run actually executes its tasks; older (past) runs are skipped, so effectively there is no backfill.
task2 is not skipped, but task3 and task4 are: task1, task3 and task4 are downstream of latest_only, so they all get skipped and task3/task4 do not run even though task2 does.
If task4 is given TriggerRule.ALL_DONE, it runs once all of its upstream tasks have finished, even when some of them were skipped.
"""
dag = DAG(
    dag_id='latest_only_operator_test',
    schedule_interval=dt.timedelta(hours=4),
    start_date=days_ago(2),
    tags=['example3'],
)

latest_only = LatestOnlyOperator(task_id='latest_only', dag=dag)
task1 = DummyOperator(task_id='task1', dag=dag)
task2 = DummyOperator(task_id='task2', dag=dag)
task3 = DummyOperator(task_id='task3', dag=dag)
task4 = DummyOperator(task_id='task4', dag=dag)
# task4 = DummyOperator(task_id='task4', dag=dag, trigger_rule=TriggerRule.ALL_DONE)

latest_only >> task1 >> [task3, task4]
task2 >> [task3, task4]
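
A minimal sketch of the TriggerRule.ALL_DONE variant described in the docstring above, assuming the same dag, latest_only, task1 and task2 objects from this example; task5 is a hypothetical extra task added only for illustration:

# Hypothetical extra task: with ALL_DONE it runs once all of its upstream tasks
# have finished, even when task1 was skipped by latest_only on a non-latest run.
task5 = DummyOperator(task_id='task5', dag=dag,
                      trigger_rule=TriggerRule.ALL_DONE)

latest_only >> task1 >> task5
task2 >> task5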
Example #24
def test_dag_edges():
    execution_date = pendulum.parse("20200101")
    with DAG("test_dag_edges", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group_a") as group_a:
            with TaskGroup("group_b") as group_b:
                task2 = DummyOperator(task_id="task2")
                task3 = DummyOperator(task_id="task3")
                task4 = DummyOperator(task_id="task4")
                task2 >> [task3, task4]

            task5 = DummyOperator(task_id="task5")

            task5 << group_b

        task1 >> group_a

        with TaskGroup("group_c") as group_c:
            task6 = DummyOperator(task_id="task6")
            task7 = DummyOperator(task_id="task7")
            task8 = DummyOperator(task_id="task8")
            [task6, task7] >> task8
            group_a >> group_c

        task5 >> task8

        task9 = DummyOperator(task_id="task9")
        task10 = DummyOperator(task_id="task10")

        group_c >> [task9, task10]

        with TaskGroup("group_d") as group_d:
            task11 = DummyOperator(task_id="task11")
            task12 = DummyOperator(task_id="task12")
            task11 >> task12

        group_d << group_c

    nodes = task_group_to_dict(dag.task_group)
    edges = dag_edges(dag)

    assert extract_node_id(nodes) == {
        'id':
        None,
        'children': [
            {
                'id':
                'group_a',
                'children': [
                    {
                        'id':
                        'group_a.group_b',
                        'children': [
                            {
                                'id': 'group_a.group_b.task2'
                            },
                            {
                                'id': 'group_a.group_b.task3'
                            },
                            {
                                'id': 'group_a.group_b.task4'
                            },
                            {
                                'id': 'group_a.group_b.downstream_join_id'
                            },
                        ],
                    },
                    {
                        'id': 'group_a.task5'
                    },
                    {
                        'id': 'group_a.upstream_join_id'
                    },
                    {
                        'id': 'group_a.downstream_join_id'
                    },
                ],
            },
            {
                'id':
                'group_c',
                'children': [
                    {
                        'id': 'group_c.task6'
                    },
                    {
                        'id': 'group_c.task7'
                    },
                    {
                        'id': 'group_c.task8'
                    },
                    {
                        'id': 'group_c.upstream_join_id'
                    },
                    {
                        'id': 'group_c.downstream_join_id'
                    },
                ],
            },
            {
                'id':
                'group_d',
                'children': [
                    {
                        'id': 'group_d.task11'
                    },
                    {
                        'id': 'group_d.task12'
                    },
                    {
                        'id': 'group_d.upstream_join_id'
                    },
                ],
            },
            {
                'id': 'task1'
            },
            {
                'id': 'task10'
            },
            {
                'id': 'task9'
            },
        ],
    }

    assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
        ('group_a.downstream_join_id', 'group_c.upstream_join_id'),
        ('group_a.group_b.downstream_join_id', 'group_a.task5'),
        ('group_a.group_b.task2', 'group_a.group_b.task3'),
        ('group_a.group_b.task2', 'group_a.group_b.task4'),
        ('group_a.group_b.task3', 'group_a.group_b.downstream_join_id'),
        ('group_a.group_b.task4', 'group_a.group_b.downstream_join_id'),
        ('group_a.task5', 'group_a.downstream_join_id'),
        ('group_a.task5', 'group_c.task8'),
        ('group_a.upstream_join_id', 'group_a.group_b.task2'),
        ('group_c.downstream_join_id', 'group_d.upstream_join_id'),
        ('group_c.downstream_join_id', 'task10'),
        ('group_c.downstream_join_id', 'task9'),
        ('group_c.task6', 'group_c.task8'),
        ('group_c.task7', 'group_c.task8'),
        ('group_c.task8', 'group_c.downstream_join_id'),
        ('group_c.upstream_join_id', 'group_c.task6'),
        ('group_c.upstream_join_id', 'group_c.task7'),
        ('group_d.task11', 'group_d.task12'),
        ('group_d.upstream_join_id', 'group_d.task11'),
        ('task1', 'group_a.upstream_join_id'),
    ]
Example #25

from airflow.utils.dates import days_ago

args = {
    "owner": "airflow",
}

dag = DAG(
    dag_id="example_branch_datetime_operator",
    start_date=days_ago(2),
    default_args=args,
    tags=["example"],
    schedule_interval="@daily",
)

# [START howto_branch_datetime_operator]
dummy_task_1 = DummyOperator(task_id='date_in_range', dag=dag)
dummy_task_2 = DummyOperator(task_id='date_outside_range', dag=dag)

cond1 = BranchDateTimeOperator(
    task_id='datetime_branch',
    follow_task_ids_if_true=['date_in_range'],
    follow_task_ids_if_false=['date_outside_range'],
    target_upper=datetime.datetime(2020, 10, 10, 15, 0, 0),
    target_lower=datetime.datetime(2020, 10, 10, 14, 0, 0),
    dag=dag,
)

# Run dummy_task_1 if cond1 executes between 2020-10-10 14:00:00 and 2020-10-10 15:00:00
cond1 >> [dummy_task_1, dummy_task_2]
# [END howto_branch_datetime_operator]
Example #26

    return 'in_accurate'


with DAG('xcom_dag',
         schedule_interval='@daily',
         default_args=default_args,
         catchup=False) as dag:

    downloading_data = BashOperator(task_id='downloading_data',
                                    bash_command='sleep 3',
                                    do_xcom_push=False)

    with TaskGroup('processing_tasks') as processing_tasks:
        training_model_a = PythonOperator(task_id='training_model_a',
                                          python_callable=_training_model)

        training_model_b = PythonOperator(task_id='training_model_b',
                                          python_callable=_training_model)

        training_model_c = PythonOperator(task_id='training_model_c',
                                          python_callable=_training_model)

    choose_model = BranchPythonOperator(task_id='task_4',
                                        python_callable=_choose_best_model)

    accurate = DummyOperator(task_id='accurate')

    in_accurate = DummyOperator(task_id='in_accurate')
    downloading_data >> processing_tasks >> choose_model
    choose_model >> [accurate, in_accurate]
Example #27
def _latest_only(**context):
    now = pendulum.now("UTC")
    left_window = context["dag"].following_schedule(context["execution_date"])
    right_window = context["dag"].following_schedule(left_window)

    if not left_window < now <= right_window:
        raise AirflowSkipException()


with DAG(
    dag_id="06_condition_dag",
    start_date=airflow.utils.dates.days_ago(3),
    schedule_interval="@daily",
) as dag:
    start = DummyOperator(task_id="start")

    pick_erp = BranchPythonOperator(
        task_id="pick_erp_system", python_callable=_pick_erp_system
    )

    fetch_sales_old = DummyOperator(task_id="fetch_sales_old")
    clean_sales_old = DummyOperator(task_id="clean_sales_old")

    fetch_sales_new = DummyOperator(task_id="fetch_sales_new")
    clean_sales_new = DummyOperator(task_id="clean_sales_new")

    join_erp = DummyOperator(task_id="join_erp_branch", trigger_rule="none_failed")

    fetch_weather = DummyOperator(task_id="fetch_weather")
    clean_weather = DummyOperator(task_id="clean_weather")
Example #28
class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
    """
    Test for SQL Branch Operator
    """
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        super().setUp()
        self.dag = DAG(
            "sql_branch_operator_test",
            default_args={
                "owner": "airflow",
                "start_date": DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        self.branch_3 = None

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_unsupported_conn_type(self):
        """Check if BranchSQLOperator throws an exception for unsupported connection type"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="redis_default",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    def test_invalid_conn(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    def test_invalid_follow_task_true(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true=None,
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    def test_invalid_follow_task_false(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false=None,
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    @pytest.mark.backend("mysql")
    def test_sql_branch_operator_mysql(self):
        """Check if BranchSQLOperator works with backend"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        branch_op.run(start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      ignore_ti_state=True)

    @pytest.mark.backend("postgres")
    def test_sql_branch_operator_postgres(self):
        """Check if BranchSQLOperator works with backend"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="postgres_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        branch_op.run(start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      ignore_ti_state=True)

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        mock_get_records.return_value = 1

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                assert ti.state == State.SUCCESS
            elif ti.task_id == "branch_1":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_2":
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_true_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for true_value in SUPPORTED_TRUE_VALUES:
            mock_get_records.return_value = true_value

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.NONE
                elif ti.task_id == "branch_2":
                    assert ti.state == State.SKIPPED
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_false_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_records.return_value = false_value
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.SKIPPED
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_list_with_dag_run(self, mock_get_db_hook):
        """Checks if the BranchSQLOperator supports branching off to a list of tasks."""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first
        mock_get_records.return_value = [["1"]]

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                assert ti.state == State.SUCCESS
            elif ti.task_id == "branch_1":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_2":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_3":
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        mock_get_records.return_value = ["Invalid Value"]

        with pytest.raises(AirflowException):
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_with_skip_in_branch_downstream_dependencies(
            self, mock_get_db_hook):
        """Test SQL Branch with skipping all downstream dependencies"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for true_value in SUPPORTED_TRUE_VALUES:
            mock_get_records.return_value = [true_value]

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.NONE
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_with_skip_in_branch_downstream_dependencies2(
            self, mock_get_db_hook):
        """Test skipping downstream dependency for false condition"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_records.return_value = [false_value]

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.SKIPPED
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")
Example #29

def _wait_for_supermarket(supermarket_id_):
    supermarket_path = Path("/data/" + supermarket_id_)
    # glob() returns a generator, which is always truthy, so materialize it to
    # check that at least one data file actually exists.
    data_files = list(supermarket_path.glob("data-*.csv"))
    success_file = supermarket_path / "_SUCCESS"
    return bool(data_files) and success_file.exists()


for supermarket_id in range(1, 5):
    wait = PythonSensor(
        task_id=f"wait_for_supermarket_{supermarket_id}",
        python_callable=_wait_for_supermarket,
        op_kwargs={"supermarket_id_": f"supermarket{supermarket_id}"},
        dag=dag1,
    )
    copy = DummyOperator(task_id=f"copy_to_raw_supermarket_{supermarket_id}",
                         dag=dag1)
    process = DummyOperator(task_id=f"process_supermarket_{supermarket_id}",
                            dag=dag1)
    trigger_create_metrics_dag = TriggerDagRunOperator(
        task_id=f"trigger_create_metrics_dag_supermarket_{supermarket_id}",
        trigger_dag_id="listing_6_04_dag02",
        dag=dag1,
    )
    wait >> copy >> process >> trigger_create_metrics_dag

compute_differences = DummyOperator(task_id="compute_differences", dag=dag2)
update_dashboard = DummyOperator(task_id="update_dashboard", dag=dag2)
notify_new_data = DummyOperator(task_id="notify_new_data", dag=dag2)
compute_differences >> update_dashboard
Example #30
def subdag_b():
    subdag_b = DAG('nested_cycle.op_subdag_0.opSubdag_B',
                   default_args=default_args)
    DummyOperator(task_id='subdag_b.task', dag=subdag_b)
    return subdag_b