def test_dagrun_success_when_all_skipped(self):
    """Tests that a DAG run succeeds when all tasks are skipped"""
    dag = DAG(
        dag_id='test_dagrun_success_when_all_skipped',
        start_date=timezone.datetime(2017, 1, 1),
    )
    short_circuit = ShortCircuitOperator(
        task_id='test_short_circuit_false',
        dag=dag,
        python_callable=lambda: False,
    )
    skipped_one = DummyOperator(task_id='test_state_skipped1', dag=dag)
    skipped_two = DummyOperator(task_id='test_state_skipped2', dag=dag)
    # Linear chain: the short-circuit task skips everything downstream.
    short_circuit >> skipped_one >> skipped_two

    dag_run = self.create_dag_run(
        dag=dag,
        state=State.RUNNING,
        task_states={
            'test_short_circuit_false': State.SUCCESS,
            'test_state_skipped1': State.SKIPPED,
            'test_state_skipped2': State.SKIPPED,
        },
    )
    dag_run.update_state()
    # A run whose leaves are all SKIPPED (with a successful root) is SUCCESS.
    self.assertEqual(State.SUCCESS, dag_run.state)
def test_clear_skipped_downstream_task(self):
    """
    After a downstream task is skipped by ShortCircuitOperator, clearing the
    skipped task should not cause it to be executed.
    """
    dag = DAG(
        'shortcircuit_clear_skipped_downstream_task',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
        schedule_interval=INTERVAL,
    )
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: False)
    downstream = DummyOperator(task_id='downstream', dag=dag)
    short_op >> downstream
    dag.clear()

    dr = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )
    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # The expected states are identical before and after the clear.
    expected = {'make_choice': State.SUCCESS, 'downstream': State.SKIPPED}

    def _assert_states(task_instances):
        for ti in task_instances:
            if ti.task_id not in expected:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            assert ti.state == expected[ti.task_id]

    tis = dr.get_task_instances()
    _assert_states(tis)

    # Clear the skipped downstream task instance.
    with create_session() as session:
        clear_task_instances(
            [ti for ti in tis if ti.task_id == "downstream"], session=session, dag=dag
        )

    # Run downstream again: it must remain SKIPPED rather than execute.
    downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    _assert_states(dr.get_task_instances())
def test_get_task_instance_on_empty_dagrun(self):
    """
    Make sure that a proper value is returned when a dagrun has no task instances
    """
    dag = DAG(
        dag_id='test_get_task_instance_on_empty_dagrun',
        start_date=timezone.datetime(2017, 1, 1),
    )
    ShortCircuitOperator(task_id='test_short_circuit_false', dag=dag, python_callable=lambda: False)

    session = settings.Session()
    now = timezone.utcnow()

    # Build the DagRun by hand rather than via create_dagrun(), because the
    # latter would also create the task instances we deliberately want absent.
    dag_run = models.DagRun(
        dag_id=dag.dag_id,
        run_type=DagRunType.MANUAL,
        execution_date=now,
        start_date=now,
        state=State.RUNNING,
        external_trigger=False,
    )
    session.add(dag_run)
    session.commit()

    # With no task instances created, the lookup must return None, not raise.
    assert dag_run.get_task_instance('test_short_circuit_false') is None
def test_with_dag_run(self):
    """ShortCircuitOperator skips downstream tasks inside a real DagRun."""
    value = False
    dag = DAG(
        'shortcircuit_operator_test_with_dag_run',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
        schedule_interval=INTERVAL,
    )
    # The lambda closes over `value`, so flipping the local below changes the
    # short-circuit decision on the second run.
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    upstream >> short_op >> branch_1 >> branch_2
    dag.clear()

    logging.error("Tasks %s", dag.tasks)
    dr = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    def _run_and_verify(expected_branch_state):
        # Execute upstream + short-circuit, then check all four task states.
        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 4)
        for ti in tis:
            if ti.task_id in ('make_choice', 'upstream'):
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id in ('branch_1', 'branch_2'):
                self.assertEqual(ti.state, expected_branch_state)
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    # Condition is False: everything downstream of the short circuit skips.
    _run_and_verify(State.SKIPPED)

    value = True
    dag.clear()
    dr.verify_integrity()
    # Condition is True: downstream tasks are left stateless, free to run.
    _run_and_verify(State.NONE)
def test_without_dag_run(self):
    """This checks the defensive against non existent tasks in a dag run"""
    value = False
    dag = DAG(
        'shortcircuit_operator_test_without_dag_run',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
        schedule_interval=INTERVAL,
    )
    # The lambda closes over `value`, flipped below for the second run.
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    upstream >> short_op >> branch_1 >> branch_2
    dag.clear()

    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == DEFAULT_DATE)

        def _verify(expected_branch_state):
            # `upstream` was never run, so no task instance may exist for it;
            # any unexpected task id (including 'upstream') is an error.
            for ti in tis:
                if ti.task_id == 'make_choice':
                    assert ti.state == State.SUCCESS
                elif ti.task_id in ('branch_1', 'branch_2'):
                    assert ti.state == expected_branch_state
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

        # Condition is False: downstream branches get SKIPPED.
        _verify(State.SKIPPED)

        value = True
        dag.clear()
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        # Condition is True: the clear left the branches stateless.
        _verify(State.NONE)
from airflow.utils import dates

args = {
    'owner': 'airflow',
}

dag = DAG(
    dag_id='example_short_circuit_operator',
    default_args=args,
    start_date=dates.days_ago(2),
    tags=['example'],
)

# One short-circuit task per outcome: a True condition lets its downstream
# tasks run, a False condition skips them.
cond_true = ShortCircuitOperator(
    task_id='condition_is_True',
    python_callable=lambda: True,
    dag=dag,
)
cond_false = ShortCircuitOperator(
    task_id='condition_is_False',
    python_callable=lambda: False,
    dag=dag,
)

# Two dummy tasks chained after each condition.
ds_true = [DummyOperator(task_id=f'true_{i}', dag=dag) for i in (1, 2)]
ds_false = [DummyOperator(task_id=f'false_{i}', dag=dag) for i in (1, 2)]

chain(cond_true, *ds_true)
chain(cond_false, *ds_false)
# Shell snippet: wait, then print the name of a single file found under the
# templated source directory (empty output if nothing is found).
locate_file_cmd = """
sleep 10
find {{params.source_location}} -type f -printf "%f\n" | head -1
"""

# Run the locate command; do_xcom_push=True publishes its stdout to XCom.
t_view = BashOperator(
    task_id="view_file",
    bash_command=locate_file_cmd,
    do_xcom_push=True,
    params={"source_location": "/your/input_dir/path"},
    dag=dag,
)

# Short-circuit the rest of the DAG when t_view produced no output
# (its XCom value, wired in via t_view.output, is the empty string).
t_is_data_available = ShortCircuitOperator(
    task_id="check_if_data_available",
    python_callable=lambda task_output: not task_output == "",
    op_kwargs=dict(task_output=t_view.output),
    dag=dag,
)

# NOTE(review): this DockerOperator call continues past the visible chunk.
t_move = DockerOperator(
    api_version="1.19",
    docker_url="tcp://localhost:2375",  # replace it with swarm/docker endpoint
    image="centos:latest",
    network_mode="bridge",
    mounts=[
        Mount(source="/your/host/input_dir/path", target="/your/input_dir/path", type="bind"),
        Mount(source="/your/host/output_dir/path", target="/your/output_dir/path", type="bind"),
task_id=f'export_missing_premsa_{gene}',
    # NOTE(review): the opening call of this operator is above the visible
    # chunk; these are its remaining keyword arguments.
    python_callable=export_sequences,
    op_kwargs={"gene": gene, "output_fn": filepath},
    dag=dag,
)

# Run the pre-MSA shell pipeline for this gene; `node` spreads work across
# 8 slots via i % 8 — presumably a compute-node index, verify against PREMSA.
pre_msa = BashOperator(
    task_id=f'pre_msa_{gene}',
    bash_command=PREMSA,
    params={
        'regions': regions,
        'filepath': filepath,
        'gene': gene,
        'node': i % 8,
        'stdout': stdout,
    },
    dag=dag,
)

# Short-circuit downstream tasks when the export file is not populated.
populated_check_task = ShortCircuitOperator(
    task_id=f'check_if_populated_{gene}',
    python_callable=is_export_populated,
    op_kwargs={'filepath': filepath},
    dag=dag,
)

# Store nuc_input, prot_input, type
import_premsa_seqs = PythonOperator(
    task_id=f'store_premsa_{gene}',
    python_callable=store_premsa_file,
    op_kwargs={
        "nuc_input": nuc_input_filepath,
        "prot_input": prot_input_filepath,
        "gene": gene,
    },
    dag=dag,
)

# NOTE(review): this PythonOperator call continues past the visible chunk.
mark_troubled_task = PythonOperator(
    task_id=f'mark_troubled_{gene}',
    python_callable=mark_troubled,
    op_kwargs={"log_file": stdout, "gene": gene},
from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import ShortCircuitOperator
from airflow.utils.trigger_rule import TriggerRule

# NOTE(review): `pendulum` is used below but imported above the visible chunk,
# and this `with DAG(...)` block continues past the end of the chunk.
with DAG(
    dag_id='example_short_circuit_operator',
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    tags=['example'],
) as dag:
    # [START howto_operator_short_circuit]
    # A True condition lets downstream tasks run; a False one skips them.
    cond_true = ShortCircuitOperator(
        task_id='condition_is_True',
        python_callable=lambda: True,
    )
    cond_false = ShortCircuitOperator(
        task_id='condition_is_False',
        python_callable=lambda: False,
    )

    # Two empty tasks chained after each condition.
    ds_true = [EmptyOperator(task_id='true_' + str(i)) for i in [1, 2]]
    ds_false = [EmptyOperator(task_id='false_' + str(i)) for i in [1, 2]]

    chain(cond_true, *ds_true)
    chain(cond_false, *ds_false)
    # [END howto_operator_short_circuit]

    # [START howto_operator_short_circuit_trigger_rules]
bash_command=locate_file_cmd,
    # NOTE(review): the opening call of this operator is above the visible
    # chunk; these are its remaining keyword arguments. do_xcom_push=True
    # publishes the command's stdout to XCom for the check task below.
    do_xcom_push=True,
    params={"source_location": "/your/input_dir/path"},
    dag=dag,
)


def is_data_available(*args, **kwargs):
    """Return True if data exists in XCom table for view_file task, false otherwise."""
    # Pull whatever the 'view_file' task pushed to XCom; an empty string
    # means the locate command found nothing.
    ti = kwargs["ti"]
    data = ti.xcom_pull(key=None, task_ids="view_file")
    return not data == ""


# Short-circuit the rest of the DAG when no data was found.
t_is_data_available = ShortCircuitOperator(
    task_id="check_if_data_available", python_callable=is_data_available, dag=dag
)

# NOTE(review): this DockerOperator call continues past the visible chunk.
t_move = DockerOperator(
    api_version="1.19",
    docker_url="tcp://localhost:2375",  # replace it with swarm/docker endpoint
    image="centos:latest",
    network_mode="bridge",
    volumes=[
        "/your/host/input_dir/path:/your/input_dir/path",
        "/your/host/output_dir/path:/your/output_dir/path",
    ],
    command=[
        "/bin/bash",
        "-c",
        "/bin/sleep 30; "