Example #1
    def test_branch_list_without_dag_run(self):
        """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
        branch_op = BranchPythonOperator(
            task_id='make_choice',
            dag=self.dag,
            python_callable=lambda: ['branch_1', 'branch_2'])
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                           TI.execution_date == DEFAULT_DATE)

            expected = {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            }

            for ti in tis:
                if ti.task_id in expected:
                    self.assertEqual(ti.state, expected[ti.task_id])
                else:
                    raise Exception
Example #2
    def test_with_skip_in_branch_downstream_dependencies2(self):
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_2')

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(run_id="manual__",
                                    start_date=timezone.utcnow(),
                                    execution_date=DEFAULT_DATE,
                                    state=State.RUNNING)

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.SKIPPED)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.NONE)
            else:
                raise Exception
Example #3
    def test_with_dag_run(self):
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                                    start_date=timezone.utcnow(),
                                    execution_date=DEFAULT_DATE,
                                    state=State.RUNNING)

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
Example #4
    def test_with_skip_in_branch_downstream_dependencies2(self):
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_2')

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.SKIPPED
            elif ti.task_id == 'branch_2':
                assert ti.state == State.NONE
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
Example #5
    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                           TI.execution_date == DEFAULT_DATE)

            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'branch_1':
                    # should exist with state None
                    self.assertEqual(ti.state, State.NONE)
                elif ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.SKIPPED)
                else:
                    raise Exception
Example #6
    def test_clear_skipped_downstream_task(self):
        """
        After a downstream task is skipped by BranchPythonOperator, clearing the skipped task
        should not cause it to be executed.
        """
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')
        branches = [self.branch_1, self.branch_2]
        branch_op >> branches
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        for task in branches:
            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

        children_tis = [
            ti for ti in tis
            if ti.task_id in branch_op.get_direct_relative_ids()
        ]

        # Clear the children tasks.
        with create_session() as session:
            clear_task_instances(children_tis, session=session, dag=self.dag)

        # Run the cleared tasks again.
        for task in branches:
            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # Check if the states are correct after children tasks are cleared.
        for ti in dr.get_task_instances():
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
Example #7
def test_empty_branch(choice, expected_states):
    """
    Tests that BranchPythonOperator handles empty branches properly.
    """
    with DAG(
            'test_empty_branch',
            start_date=DEFAULT_DATE,
    ) as dag:
        branch = BranchPythonOperator(task_id='branch',
                                      python_callable=lambda: choice)
        task1 = DummyOperator(task_id='task1')
        join = DummyOperator(task_id='join',
                             trigger_rule="none_failed_or_skipped")

        branch >> [task1, join]
        task1 >> join

    dag.clear(start_date=DEFAULT_DATE)

    task_ids = ["branch", "task1", "join"]

    tis = {}
    for task_id in task_ids:
        task_instance = TI(dag.get_task(task_id), execution_date=DEFAULT_DATE)
        tis[task_id] = task_instance
        task_instance.run()

    def get_state(ti):
        ti.refresh_from_db()
        return ti.state

    assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states
Example #8
def test_parent_skip_branch():
    """
    A simple DAG with a BranchPythonOperator that does not follow op2. NotPreviouslySkippedDep is not met.
    """
    with create_session() as session:
        session.query(DagRun).delete()
        session.query(TaskInstance).delete()
        start_date = pendulum.datetime(2020, 1, 1)
        dag = DAG("test_parent_skip_branch_dag",
                  schedule_interval=None,
                  start_date=start_date)
        dag.create_dagrun(run_type=DagRunType.MANUAL,
                          state=State.RUNNING,
                          execution_date=start_date)
        op1 = BranchPythonOperator(task_id="op1",
                                   python_callable=lambda: "op3",
                                   dag=dag)
        op2 = DummyOperator(task_id="op2", dag=dag)
        op3 = DummyOperator(task_id="op3", dag=dag)
        op1 >> [op2, op3]
        TaskInstance(op1, start_date).run()
        ti2 = TaskInstance(op2, start_date)
        dep = NotPreviouslySkippedDep()

        assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 1
        session.commit()
        assert not dep.is_met(ti2, session)
        assert ti2.state == State.SKIPPED
Example #9
def _training_model(model):
    return randint(1, 10)


with DAG("my_dag",
         start_date=datetime(2021, 1, 1),
         schedule_interval='@daily',
         catchup=False) as dag:

    training_model_tasks = [
        PythonOperator(
            task_id=f"training_model_{model_id}",
            python_callable=_training_model,
            op_kwargs={
                "model": model_id
            }
        ) for model_id in ['A', 'B', 'C']
    ]

    choosing_best_model = BranchPythonOperator(
        task_id="choosing_best_model",
        python_callable=_choosing_best_model
    )

    accurate = BashOperator(
        task_id="accurate",
        bash_command="echo 'accurate'"
    )

    inaccurate = BashOperator(
        task_id="inaccurate",
        bash_command="echo 'inaccurate'"
    )

    training_model_tasks >> choosing_best_model >> [accurate, inaccurate]
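
The snippet above references _choosing_best_model but does not show it. A minimal sketch of such a callable, assuming the three training tasks push their random scores to XCom and using an illustrative threshold of 8, could be:

def _choosing_best_model(ti):
    # Hypothetical helper, not part of the original snippet: pull the values
    # returned (and auto-pushed to XCom) by the three training tasks.
    accuracies = ti.xcom_pull(task_ids=[
        'training_model_A', 'training_model_B', 'training_model_C'
    ])
    # Branch to the 'accurate' task if any model clears the threshold.
    if max(accuracies) > 8:
        return 'accurate'
    return 'inaccurate'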
Example #10
    def test_xcom_push(self):
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(run_id="manual__",
                                    start_date=timezone.utcnow(),
                                    execution_date=DEFAULT_DATE,
                                    state=State.RUNNING)

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.xcom_pull(task_ids='make_choice'),
                                 'branch_1')
Example #11
def model():
    """
    Model creating tasks such as optimization, training, evaluation and serving.
    """

    # Extract features
    extract_features = PythonOperator(
        task_id="extract_features",
        python_callable=cli.get_historical_features,
    )

    # Optimization
    optimization = BashOperator(
        task_id="optimization",
        bash_command="echo `tagifai optimize`",
    )

    # Train model
    train = BashOperator(
        task_id="train",
        bash_command="echo `tagifai train-model`",
    )

    # Evaluate model
    evaluate = BranchPythonOperator(  # BranchPythonOperator returns a task_id or [task_ids]
        task_id="evaluate",
        python_callable=_evaluate_model,
    )

    # Improved or regressed
    improved = BashOperator(
        task_id="improved",
        bash_command="echo IMPROVED",
    )
    regressed = BashOperator(
        task_id="regressed",
        bash_command="echo REGRESSED",
    )

    # Serve model
    serve = BashOperator(
        task_id="serve",  # push to GitHub to kick off serving workflows
        bash_command="echo served model",  # or to a purpose-built model server, etc.
    )

    # Notifications (use appropriate operators, ex. EmailOperator)
    report = BashOperator(task_id="report", bash_command="echo filed report")

    # Task relationships
    extract_features >> optimization >> train >> evaluate >> [improved, regressed]
    improved >> serve
    regressed >> report
Example #12
def update():
    """
    Model updating tasks such as monitoring, retraining, etc.
    """
    # Monitoring (inputs, predictions, etc.)
    # Considers thresholds, windows, frequency, etc.
    monitoring = BashOperator(
        task_id="monitoring",
        bash_command="echo monitoring",
    )

    # Update policy engine (continue, improve, rollback, etc.)
    update_policy_engine = BranchPythonOperator(
        task_id="update_policy_engine",
        python_callable=_update_policy_engine,
    )

    # Policies
    _continue = BashOperator(
        task_id="continue",
        bash_command="echo continue",
    )
    inspect = BashOperator(
        task_id="inspect",
        bash_command="echo inspect",
    )
    improve = BashOperator(
        task_id="improve",
        bash_command="echo improve",
    )
    rollback = BashOperator(
        task_id="rollback",
        bash_command="echo rollback",
    )

    # Compose retraining dataset
    # Labeling, QA, augmentation, upsample poor slices, weight samples, etc.
    compose_retraining_dataset = BashOperator(
        task_id="compose_retraining_dataset",
        bash_command="echo compose retraining dataset",
    )

    # Retrain (initiates model creation workflow)
    retrain = BashOperator(
        task_id="retrain",
        bash_command="echo retrain",
    )

    # Task relationships
    monitoring >> update_policy_engine >> [_continue, inspect, improve, rollback]
    improve >> compose_retraining_dataset >> retrain
Example #13
def test_parent_not_executed():
    """
    A simple DAG with a BranchPythonOperator that does not follow op2. Parent task is not yet
    executed (no xcom data). NotPreviouslySkippedDep is met (no decision).
    """
    start_date = pendulum.datetime(2020, 1, 1)
    dag = DAG("test_parent_not_executed_dag",
              schedule_interval=None,
              start_date=start_date)
    op1 = BranchPythonOperator(task_id="op1",
                               python_callable=lambda: "op3",
                               dag=dag)
    op2 = DummyOperator(task_id="op2", dag=dag)
    op3 = DummyOperator(task_id="op3", dag=dag)
    op1 >> [op2, op3]

    ti2 = TaskInstance(op2, start_date)

    with create_session() as session:
        dep = NotPreviouslySkippedDep()
        assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0
        assert dep.is_met(ti2, session)
        assert ti2.state == State.NONE
Example #14
def test_parent_follow_branch():
    """
    A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met.
    """
    start_date = pendulum.datetime(2020, 1, 1)
    dag = DAG("test_parent_follow_branch_dag",
              schedule_interval=None,
              start_date=start_date)
    dag.create_dagrun(run_type=DagRunType.MANUAL,
                      state=State.RUNNING,
                      execution_date=start_date)
    op1 = BranchPythonOperator(task_id="op1",
                               python_callable=lambda: "op2",
                               dag=dag)
    op2 = DummyOperator(task_id="op2", dag=dag)
    op1 >> op2
    TaskInstance(op1, start_date).run()
    ti2 = TaskInstance(op2, start_date)

    with create_session() as session:
        dep = NotPreviouslySkippedDep()
        assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0
        assert dep.is_met(ti2, session)
        assert ti2.state != State.SKIPPED
Example #15
    if (best_accuracy > 8):
        return 'accurate'
    return 'inaccurate'


def _training_model():
    return randint(1, 10)


with DAG("first_dag",
         start_date=datetime(2021, 1, 1),
         schedule_interval="@daily",
         catchup=False) as dag:

    training_model_A = PythonOperator(task_id="training_model_A",
                                      python_callable=_training_model)

    training_model_B = PythonOperator(task_id="training_model_B",
                                      python_callable=_training_model)

    training_model_C = PythonOperator(task_id="training_model_C",
                                      python_callable=_training_model)

    choose_best_model = BranchPythonOperator(
        task_id="choose_best_model", python_callable=_choose_best_model)

    accurate = BashOperator(task_id="accurate", bash_command="echo 'accurate'")

    inaccurate = BashOperator(task_id="inaccurate",
                              bash_command="echo 'inaccurate'")
Example #16
    default_args=args,
    tags=['example']
)


def should_run(**kwargs):
    """
    Determine which dummy_task should be run based on whether the execution date minute is even or odd.

    :param dict kwargs: Context
    :return: Id of the task to run
    :rtype: str
    """
    print('------------- exec dttm = {} and minute = {}'.
          format(kwargs['execution_date'], kwargs['execution_date'].minute))
    if kwargs['execution_date'].minute % 2 == 0:
        return "dummy_task_1"
    else:
        return "dummy_task_2"


cond = BranchPythonOperator(
    task_id='condition',
    python_callable=should_run,
    dag=dag,
)

dummy_task_1 = DummyOperator(task_id='dummy_task_1', dag=dag)
dummy_task_2 = DummyOperator(task_id='dummy_task_2', dag=dag)
cond >> [dummy_task_1, dummy_task_2]
Example #17
          "8 AM - 12 PM: Run & Workout\n"
          "12 PM - 10 PM: Code & Relax")


def weekday():
    print("Schedule:\n"
          "6 AM - 9 AM: Run & Workout\n"
          "9 AM - 5 PM: Work\n"
          "5 PM - 10 PM: Code & Relax")


default_args = {'owner': 'airflow', 'depends_on_past': False}

with DAG(dag_id="branch",
         description="A DAG that branches",
         default_args=default_args,
         dagrun_timeout=timedelta(hours=2),
         start_date=days_ago(1),
         schedule_interval="@daily",
         default_view="graph",
         is_paused_upon_creation=False,
         tags=["sample", "branch", "python"]) as dag:
    branch_task = BranchPythonOperator(task_id='branch',
                                       python_callable=branch)
    weekend_task = PythonOperator(task_id='weekend_task',
                                  python_callable=weekend)
    weekday_task = PythonOperator(task_id='weekday_task',
                                  python_callable=weekday)

    branch_task >> [weekend_task, weekday_task]
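
The excerpt above does not include the branch callable passed to branch_task. A minimal sketch, assuming the decision is simply based on the day of the week, might look like:

def branch():
    # Hypothetical branching logic (not shown in the excerpt): return the
    # weekend task id on Saturday/Sunday, the weekday task id otherwise.
    from datetime import datetime
    if datetime.now().weekday() >= 5:  # 5 = Saturday, 6 = Sunday
        return 'weekend_task'
    return 'weekday_task'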
Example #18
        for item in rows:
            dataset_total_rows = item[0]
            if dataset_total_rows == mysql_total_rows:
                print('bigquery_is_up_to_date')
                return 'bigquery_is_up_to_date'
            elif dataset_total_rows < mysql_total_rows:
                print('rows in bq dataset are not up to date with mysql')
                return 'extract_mysql_to_local_pq'
    except GoogleAPIError:
        # The BigQuery query failed (typically the dataset is not found yet),
        # so branch to the extraction task.
        print(rows.error_result)
        return 'extract_mysql_to_local_pq'


check_data = BranchPythonOperator(task_id='check_data',
                                  python_callable=check_data,
                                  op_kwargs=func_param,
                                  dag=dag)

download_zip = PythonOperator(task_id='download_zip',
                              python_callable=retrive_csv_file,
                              op_kwargs=func_param,
                              trigger_rule='none_skipped',
                              dag=dag)

load_to_mysql = PythonOperator(task_id='load_to_mysql',
                               python_callable=csv_load_to_db,
                               op_kwargs=func_param,
                               trigger_rule='none_failed_or_skipped',
                               dag=dag)

check_dataset = BranchPythonOperator(task_id='check_dataset',
Example #19
    upload_files = PythonOperator(
        task_id="upload_files",
        provide_context=True,
        python_callable=upload_remote_files,
    )

    build_message = PythonOperator(
        task_id="build_message",
        provide_context=True,
        python_callable=build_notification_message,
    )

    branching = BranchPythonOperator(
        task_id="branching",
        provide_context=True,
        python_callable=return_branch,
    )

    send_notification = PythonOperator(
        task_id="send_notification",
        provide_context=True,
        python_callable=send_success_msg,
    )

    no_notification = DummyOperator(task_id="no_need_for_notification")

    load_files >> get_packages >> upload_files >> build_message >> branching
    branching >> send_notification
    branching >> no_notification
Example #20
seven_days_ago = datetime.combine(datetime.today() - timedelta(7),
                                  datetime.min.time())
args = {
    'owner': 'airflow',
    'start_date': seven_days_ago,
}

dag = DAG(dag_id='example_branch_operator',
          default_args=args,
          schedule_interval="@daily")

cmd = 'ls -l'
run_this_first = DummyOperator(task_id='run_this_first', dag=dag)

options = ['branch_a', 'branch_b', 'branch_c', 'branch_d']

branching = BranchPythonOperator(
    task_id='branching',
    python_callable=lambda: random.choice(options),
    dag=dag)
branching.set_upstream(run_this_first)

join = DummyOperator(task_id='join', trigger_rule='one_success', dag=dag)

for option in options:
    t = DummyOperator(task_id=option, dag=dag)
    t.set_upstream(branching)
    dummy_follow = DummyOperator(task_id='follow_' + option, dag=dag)
    t.set_downstream(dummy_follow)
    dummy_follow.set_downstream(join)
Example #21
    src = DownloadFileOperator(
        task_id="get_data",
        file_url=SRC_URL,
        dir=tmp_folder,
        filename="src_data.json",
    )

    get_package = GetPackageOperator(
        task_id="get_package",
        address=ckan_address,
        apikey=ckan_apikey,
        package_name_or_id=PACKAGE_NAME,
    )

    res_new_or_existing = BranchPythonOperator(
        task_id="res_new_or_existing",
        python_callable=is_resource_new,
    )

    transformed_data = PythonOperator(
        task_id="transform_data",
        python_callable=transform_data,
    )

    create_data_dictionary = PythonOperator(
        task_id="create_data_dictionary",
        python_callable=build_data_dict,
    )

    get_or_create_resource = GetOrCreateResourceOperator(
        task_id="get_or_create_resource",
        address=ckan_address,
Example #22
"""
Example DAG demonstrating a workflow with nested branching. The join tasks are created with
``none_failed_or_skipped`` trigger rule such that they are skipped whenever their corresponding
``BranchPythonOperator`` are skipped.
"""

from airflow.models import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import BranchPythonOperator
from airflow.utils.dates import days_ago

with DAG(dag_id="example_nested_branch_dag",
         start_date=days_ago(2),
         schedule_interval="@daily",
         tags=["example"]) as dag:
    branch_1 = BranchPythonOperator(task_id="branch_1",
                                    python_callable=lambda: "true_1")
    join_1 = DummyOperator(task_id="join_1",
                           trigger_rule="none_failed_or_skipped")
    true_1 = DummyOperator(task_id="true_1")
    false_1 = DummyOperator(task_id="false_1")
    branch_2 = BranchPythonOperator(task_id="branch_2",
                                    python_callable=lambda: "true_2")
    join_2 = DummyOperator(task_id="join_2",
                           trigger_rule="none_failed_or_skipped")
    true_2 = DummyOperator(task_id="true_2")
    false_2 = DummyOperator(task_id="false_2")
    false_3 = DummyOperator(task_id="false_3")

    branch_1 >> true_1 >> join_1
    branch_1 >> false_1 >> branch_2 >> [true_2, false_2] >> join_2 >> false_3 >> join_1
Example #23
def branch_func(**context):
    if random.random() < 0.5:
        return 'say_hi'
    return 'say_hello'


run_this_task = PythonOperator(task_id='run_this',
                               python_callable=push_to_xcom,
                               provide_context=True,
                               retries=10,
                               retry_delay=timedelta(seconds=1),
                               dag=dag)

run_this_task_2 = PythonOperator(task_id='say_hi',
                                 python_callable=print_hi,
                                 provide_context=True,
                                 dag=dag)

run_this_task_3 = PythonOperator(task_id='say_hello',
                                 python_callable=print_hello,
                                 provide_context=True,
                                 dag=dag)

branch_op = BranchPythonOperator(task_id='branch_task',
                                 python_callable=branch_func,
                                 provide_context=True,
                                 dag=dag)

run_this_task >> branch_op >> [run_this_task_2, run_this_task_3]
Example #24
with DAG('xcom_dag',
         schedule_interval='@daily',
         default_args=default_args,
         catchup=False) as dag:

    downloading_data = BashOperator(task_id='downloading_data',
                                    bash_command='sleep 3',
                                    do_xcom_push=False)

    with TaskGroup('processing_tasks') as processing_tasks:
        training_model_a = PythonOperator(task_id='training_model_a',
                                          python_callable=_training_model)

        training_model_b = PythonOperator(task_id='training_model_b',
                                          python_callable=_training_model)

        training_model_c = PythonOperator(task_id='training_model_c',
                                          python_callable=_training_model)

    choose_model = PythonOperator(task_id='task_4',
                                  python_callable=_choose_best_model)

    is_accurate = BranchPythonOperator(task_id='is_accurate',
                                       python_callable=_is_accurate)

    accurate = DummyOperator(task_id='accurate')

    inaccurate = DummyOperator(task_id='inaccurate')

    downloading_data >> processing_tasks >> choose_model
    choose_model >> is_accurate >> [accurate, inaccurate]
Example #25
    print('This is even ::' + str(n))


def methodPrint(n):
    print('This is odd ::' + str(n))


def the_end():
    print('The End')


with DAG(dag_id='TaskGroup_BranchPythonOperator',
         schedule_interval=None,
         start_date=days_ago(2)) as dag:
    task_1 = PythonOperator(task_id='task_1', python_callable=method1)
    task_2 = BranchPythonOperator(task_id='task_2', python_callable=method2)
    with TaskGroup('group1') as group1:

        task_x = PythonOperator(task_id='task_x',
                                python_callable=printMethod,
                                op_kwargs={'n': 1})
        task_n = [
            PythonOperator(task_id=f'task_{i}',
                           python_callable=printMethod,
                           op_kwargs={'n': i}) for i in range(2, 6)
        ]
        task_x >> task_n

    with TaskGroup('group2') as group2:

        task_x = PythonOperator(task_id='task_x',
Example #26
def _fetch_dataset_old():
    print("Fetching data (OLD)...")


def _fetch_dataset_new():
    print("Fetching data (NEW)...")


with DAG(
        dag_id="03_branching",
        start_date=airflow.utils.dates.days_ago(3),
        schedule_interval="@daily",
) as dag:
    start = DummyOperator(task_id="start")

    pick_branch = BranchPythonOperator(task_id="pick_branch",
                                       python_callable=_pick_branch)

    fetch_dataset_old = PythonOperator(task_id="fetch_dataset_old",
                                       python_callable=_fetch_dataset_old)

    fetch_dataset_new = PythonOperator(task_id="fetch_dataset_new",
                                       python_callable=_fetch_dataset_new)

    fetch_another_dataset = DummyOperator(task_id="fetch_another_dataset")

    join_datasets = DummyOperator(task_id="join_datasets",
                                  trigger_rule="none_failed")

    train_model = DummyOperator(task_id="train_model")
    deploy_model = DummyOperator(task_id="deploy_model")
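
The 03_branching excerpt above omits _pick_branch. A plausible sketch, assuming the branch is chosen by comparing the run's execution date against an illustrative cutover date, could be:

import pendulum

def _pick_branch(**context):
    # Hypothetical cutover date: runs before it use the old dataset fetcher.
    cutover = pendulum.datetime(2021, 1, 1, tz="UTC")
    if context["execution_date"] < cutover:
        return "fetch_dataset_old"
    return "fetch_dataset_new"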
Example #27
def _is_latest_run(**context):
    now = pendulum.now("UTC")
    left_window = context["dag"].following_schedule(context["execution_date"])
    right_window = context["dag"].following_schedule(left_window)
    return left_window < now <= right_window


with DAG(
        dag_id="05_condition_function",
        start_date=airflow.utils.dates.days_ago(3),
        schedule_interval="@daily",
) as dag:
    start = DummyOperator(task_id="start")

    pick_erp = BranchPythonOperator(task_id="pick_erp_system",
                                    python_callable=_pick_erp_system)

    fetch_sales_old = DummyOperator(task_id="fetch_sales_old")
    clean_sales_old = DummyOperator(task_id="clean_sales_old")

    fetch_sales_new = DummyOperator(task_id="fetch_sales_new")
    clean_sales_new = DummyOperator(task_id="clean_sales_new")

    join_erp = DummyOperator(task_id="join_erp_branch",
                             trigger_rule="none_failed")

    fetch_weather = DummyOperator(task_id="fetch_weather")
    clean_weather = DummyOperator(task_id="clean_weather")

    join_datasets = DummyOperator(task_id="join_datasets")
    train_model = DummyOperator(task_id="train_model")
Example #28
    )

    hive_s3_location = QuboleOperator(
        task_id='hive_s3_location',
        command_type="hivecmd",
        script_location=
        "s3n://public-qubole/qbol-library/scripts/show_table.hql",
        notify=True,
        tags=['tag1', 'tag2'],
        # If the script at s3 location has any qubole specific macros to be replaced
        # macros='[{"date": "{{ ds }}"}, {"name" : "abc"}]',
    )

    options = ['hadoop_jar_cmd', 'presto_cmd', 'db_query', 'spark_cmd']

    branching = BranchPythonOperator(
        task_id='branching', python_callable=lambda: random.choice(options))

    [hive_show_table, hive_s3_location] >> compare_result(
        hive_s3_location, hive_show_table) >> branching

    join = DummyOperator(task_id='join', trigger_rule=TriggerRule.ONE_SUCCESS)

    hadoop_jar_cmd = QuboleOperator(
        task_id='hadoop_jar_cmd',
        command_type='hadoopcmd',
        sub_command='jar s3://paid-qubole/HadoopAPIExamples/'
        'jars/hadoop-0.20.1-dev-streaming.jar '
        '-mapper wc '
        '-numReduceTasks 0 -input s3://paid-qubole/HadoopAPITests/'
        'data/3.tsv -output '
        's3://paid-qubole/HadoopAPITests/data/3_wc',
Example #29
            'channelId': YOUTUBE_CHANNEL_ID,
            'maxResults': 50,
            'publishedAfter': YOUTUBE_VIDEO_PUBLISHED_AFTER,
            'publishedBefore': YOUTUBE_VIDEO_PUBLISHED_BEFORE,
            'type': 'video',
            'fields': 'items/id/videoId'
        },
        google_api_response_via_xcom='video_ids_response',
        s3_destination_key=f'{s3_directory}/youtube_search_{s3_file_name}.json',
        task_id='video_ids_to_s3')
    # [END howto_operator_google_api_to_s3_transfer_advanced_task_1]
    # [START howto_operator_google_api_to_s3_transfer_advanced_task_1_1]
    task_check_and_transform_video_ids = BranchPythonOperator(
        python_callable=_check_and_transform_video_ids,
        op_args=[
            task_video_ids_to_s3.google_api_response_via_xcom,
            task_video_ids_to_s3.task_id
        ],
        task_id='check_and_transform_video_ids')
    # [END howto_operator_google_api_to_s3_transfer_advanced_task_1_1]
    # [START howto_operator_google_api_to_s3_transfer_advanced_task_2]
    task_video_data_to_s3 = GoogleApiToS3Operator(
        gcp_conn_id=YOUTUBE_CONN_ID,
        google_api_service_name='youtube',
        google_api_service_version='v3',
        google_api_endpoint_path='youtube.videos.list',
        google_api_endpoint_params={
            'part': YOUTUBE_VIDEO_PARTS,
            'maxResults': 50,
            'fields': YOUTUBE_VIDEO_FIELDS
        },
Example #30
def create_dag(dag_id, dataset, schedule, default_args):

    dag = DAG(dag_id,
              schedule_interval=schedule,
              default_args=default_args,
              tags=['police', 'gary'])

    with dag:
        # init vars
        data_filename = "data.json"
        agol_dataset = dataset["agol_dataset"]
        package_name = dataset['package_id']
        resource_name = dataset["name"]

        tmp_dir = CreateLocalDirectoryOperator(task_id="tmp_dir",
                                               path=TMP_DIR / dag_id)

        fields_filepath = TMP_DIR / dag_id / "fields.json"

        get_package = GetPackageOperator(
            task_id="get_package",
            address=CKAN,
            apikey=CKAN_APIKEY,
            package_name_or_id=package_name,
        )

        def is_resource_new(**kwargs):
            resource = kwargs["ti"].xcom_pull(
                task_ids="get_or_create_resource")

            if resource["is_new"]:
                return "new_resource"

            return "existing_resource"

        new_or_existing = BranchPythonOperator(
            task_id="new_or_existing",
            python_callable=is_resource_new,
        )

        new_resource = DummyOperator(task_id="new_resource", dag=dag)
        existing_resource = DummyOperator(task_id="existing_resource", dag=dag)

        get_agol_data = AGOLDownloadFileOperator(
            task_id="get_agol_data",
            request_url=base_url + agol_dataset + "/FeatureServer/0/",
            dir=TMP_DIR / dag_id,
            filename=data_filename,
            delete_col=["geometry", "objectid"]
            #on_success_callback=task_success_slack_alert,
        )

        get_or_create_resource = GetOrCreateResourceOperator(
            task_id="get_or_create_resource",
            address=CKAN,
            apikey=CKAN_APIKEY,
            package_name_or_id=package_name,
            resource_name=resource_name,
            resource_attributes=dict(
                format="csv",
                package_id=package_name,
                is_preview=True,
                url_type="datastore",
                extract_job=f"Airflow: {package_name}",
            ),
        )

        def set_fields(**kwargs):
            ti = kwargs["ti"]
            tmp_dir = Path(ti.xcom_pull(task_ids="tmp_dir"))
            with open(fields_filepath, 'w') as fields_json_file:
                json.dump(fields[agol_dataset], fields_json_file)

            return fields_filepath

        set_fields = PythonOperator(task_id="set_fields",
                                    python_callable=set_fields,
                                    trigger_rule="none_failed")

        backup_resource = BackupDatastoreResourceOperator(
            task_id="backup_resource",
            address=CKAN,
            apikey=CKAN_APIKEY,
            resource_task_id="get_or_create_resource",
            dir_task_id="tmp_dir")

        # delete_records = DeleteDatastoreResourceOperator(
        #     task_id="delete_records",
        #     address = CKAN,
        #     apikey = CKAN_APIKEY,
        #     resource_id_task_id = "get_or_create_resource",
        #     resource_id_task_key = "id",
        #     trigger_rule="none_failed",
        # )

        delete_records = DeleteDatastoreResourceRecordsOperator(
            task_id="delete_records",
            address=CKAN,
            apikey=CKAN_APIKEY,
            backup_task_id="backup_resource",
        )

        insert_records = InsertDatastoreResourceRecordsFromJSONOperator(
            task_id="insert_records",
            address=CKAN,
            apikey=CKAN_APIKEY,
            resource_id_task_id="get_or_create_resource",
            resource_id_task_key="id",
            data_path=TMP_DIR / dag_id / data_filename,
            fields_path=fields_filepath)

        modify_metadata = EditResourceMetadataOperator(
            task_id="modify_metadata",
            address=CKAN,
            apikey=CKAN_APIKEY,
            resource_id_task_id="get_or_create_resource",
            resource_id_task_key="id",
            new_resource_name=resource_name,
            last_modified_task_id="get_agol_data",
            last_modified_task_key="last_modified")

        restore_backup = RestoreDatastoreResourceBackupOperator(
            task_id="restore_backup",
            address=CKAN,
            apikey=CKAN_APIKEY,
            backup_task_id="backup_resource",
        )

        delete_tmp_dir = DeleteLocalDirectoryOperator(
            task_id="delete_tmp_dir",
            path=TMP_DIR / dag_id,
            #on_success_callback=task_success_slack_alert,
        )

        job_success = DummyOperator(task_id="job_success",
                                    trigger_rule="all_success")

        job_failed = DummyOperator(task_id="job_failed",
                                   trigger_rule="one_failed")

        message_slack_success = GenericSlackOperator(
            task_id="message_slack_success",
            message_header=dag_id + " run successfully.",
            message_content_task_id="insert_records",
            message_content_task_key="data_inserted")

        message_slack_abort = GenericSlackOperator(
            task_id="message_slack_abort",
            message_header=dag_id + " failed, data records not changed",
        )

        message_slack_recover = GenericSlackOperator(
            task_id="message_slack_recover",
            message_header=dag_id + " failed to update, data records restored",
        )

        join_or = DummyOperator(task_id="join_or",
                                dag=dag,
                                trigger_rule="one_success")
        join_and = DummyOperator(task_id="join_and",
                                 dag=dag,
                                 trigger_rule="all_success")

        ## DAG EXECUTION LOGIC
        tmp_dir >> get_agol_data >> join_and
        get_package >> get_or_create_resource >> new_or_existing >> [
            new_resource, existing_resource
        ]
        new_resource >> set_fields >> join_or >> join_and
        existing_resource >> backup_resource >> delete_records >> join_or >> join_and
        join_and >> insert_records >> modify_metadata >> job_success >> delete_tmp_dir >> message_slack_success
        [get_agol_data, get_or_create_resource] >> job_failed >> message_slack_abort
        [insert_records] >> job_failed >> restore_backup >> message_slack_recover

    return dag