def test_without_dag_run(self):
    """This checks the defensive handling of non-existent tasks in a dag run"""
    self.branch_op = BranchPythonOperator(task_id='make_choice',
                                          dag=self.dag,
                                          python_callable=lambda: 'branch_1')
    self.branch_1.set_upstream(self.branch_op)
    self.branch_2.set_upstream(self.branch_op)
    self.dag.clear()

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(
            TI.dag_id == self.dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )

        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                # should exist with state None
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise Exception

def create_dag(dag_id, value):
    def run_print_var():
        return "go_fail"

    default_args = {
        'owner': 'kwas',
        'start_date': datetime(2018, 9, 6),
        'var': 'default'
    }
    dag = DAG(dag_id, default_args=default_args)

    print_date = BashOperator(task_id='print_date', bash_command='date', dag=dag)
    branch = BranchPythonOperator(task_id='branch',
                                  python_callable=run_print_var,
                                  dag=dag)
    branch.set_upstream(print_date)
    fail = BashOperator(
        task_id='go_fail',
        bash_command='if [ ! -f /tmp/kwas-fail ]; then exit 1; fi',
        dag=dag)
    fail.set_upstream(branch)
    finish = BashOperator(task_id='final_task',
                          bash_command='echo finish',
                          trigger_rule='all_success',
                          dag=dag)
    finish.set_upstream(fail)
    return dag

def test_with_skip_in_branch_downstream_dependencies2(self):
    self.branch_op = BranchPythonOperator(task_id='make_choice',
                                          dag=self.dag,
                                          python_callable=lambda: 'branch_2')
    self.branch_op >> self.branch_1 >> self.branch_2
    self.branch_op >> self.branch_2
    self.dag.clear()

    dr = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING
    )

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    tis = dr.get_task_instances()
    for ti in tis:
        if ti.task_id == 'make_choice':
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_1':
            self.assertEqual(ti.state, State.SKIPPED)
        elif ti.task_id == 'branch_2':
            self.assertEqual(ti.state, State.NONE)
        else:
            raise Exception

def test_branch_list_without_dag_run(self):
    """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
    self.branch_op = BranchPythonOperator(task_id='make_choice',
                                          dag=self.dag,
                                          python_callable=lambda: ['branch_1', 'branch_2'])
    self.branch_1.set_upstream(self.branch_op)
    self.branch_2.set_upstream(self.branch_op)
    self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
    self.branch_3.set_upstream(self.branch_op)
    self.dag.clear()

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(
            TI.dag_id == self.dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )

        expected = {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.NONE,
            "branch_3": State.SKIPPED,
        }

        for ti in tis:
            if ti.task_id in expected:
                self.assertEqual(ti.state, expected[ti.task_id])
            else:
                raise Exception

def extract_network_externals(parent_dag_name, child_dag_name, start_date, schedule_interval):
    """
    Live network external definitions for all vendors

    :param parent_dag_name:
    :param child_dag_name:
    :param start_date:
    :param schedule_interval:
    :return:
    """
    dag = DAG(
        '%s.%s' % (parent_dag_name, child_dag_name),
        schedule_interval=schedule_interval,
        start_date=start_date,
    )

    branch_externals_task = DummyOperator(task_id='branch_externals', dag=dag)
    join_externals_task = DummyOperator(task_id='join_externals', dag=dag)

    def extract_external_definitions_on_ericsson():
        ericsson_cm.extract_live_network_externals_on_2g()
        ericsson_cm.extract_live_network_externals_on_3g()
        ericsson_cm.extract_live_network_externals_on_4g()

    extract_external_definitions_on_ericsson_task = BranchPythonOperator(
        task_id='extract_external_definitions_on_ericsson',
        python_callable=extract_external_definitions_on_ericsson,
        dag=dag)

    def extract_external_definitions_on_huawei():
        huawei_cm.extract_live_network_externals_on_2g()
        huawei_cm.extract_live_network_externals_on_3g()
        huawei_cm.extract_live_network_externals_on_4g()

    extract_external_definitions_on_huawei_task = BranchPythonOperator(
        task_id='extract_external_definitions_on_huawei',
        python_callable=extract_external_definitions_on_huawei,
        dag=dag)

    def extract_external_definitions_on_zte():
        zte_cm.extract_live_network_externals_on_2g()
        zte_cm.extract_live_network_externals_on_3g()
        zte_cm.extract_live_network_externals_on_4g()

    extract_external_definitions_on_zte_task = BranchPythonOperator(
        task_id='extract_external_definitions_on_zte',
        python_callable=extract_external_definitions_on_zte,
        dag=dag)

    dag.set_dependency('branch_externals', 'extract_external_definitions_on_ericsson')
    dag.set_dependency('branch_externals', 'extract_external_definitions_on_huawei')
    dag.set_dependency('branch_externals', 'extract_external_definitions_on_zte')

    dag.set_dependency('extract_external_definitions_on_ericsson', 'join_externals')
    dag.set_dependency('extract_external_definitions_on_huawei', 'join_externals')
    dag.set_dependency('extract_external_definitions_on_zte', 'join_externals')

    return dag

class BranchOperatorTest(unittest.TestCase):
    def setUp(self):
        self.dag = DAG('branch_operator_test',
                       default_args={
                           'owner': 'airflow',
                           'start_date': DEFAULT_DATE},
                       schedule_interval=INTERVAL)
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

    def test_without_dag_run(self):
        """This checks the defensive handling of non-existent tasks in a dag run"""
        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        session = Session()
        tis = session.query(TI).filter(
            TI.dag_id == self.dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )

        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                # should exist with state None
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise Exception
        # Close only after the lazy query has been iterated.
        session.close()

    def test_with_dag_run(self):
        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=datetime.datetime.now(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise Exception

def get_grupo_dados(dag, previous_task, next_task, dados):
    for dado in dados:
        extracao = SimpleHttpOperator(
            task_id='Extracao_de_dados_{}'.format(dado),
            endpoint='url...',
            method='GET',
            trigger_rule="all_success",
            dag=dag)

        email_erro = EmailOperator(
            task_id='Email_Erro_{}'.format(dado),
            to='*****@*****.**',
            subject='Airflow Alert Erro',
            html_content='Erro ao realizar captura de {}'.format(dado),
            dag=dag,
            trigger_rule="all_failed",
            default_args={
                'email': ['*****@*****.**'],
                'email_on_failure': True,
                'email_on_retry': True,
                'retries': 2,
                'retry_delay': timedelta(minutes=5)
            })

        salvar_base_raw = BranchPythonOperator(
            task_id='Salvar_DB_Raw_{}'.format(dado),
            python_callable=salva_dados_db_raw,
            trigger_rule="all_success",
            dag=dag)

        stop_falha = BranchPythonOperator(
            task_id='Stop_erro_extracao_{}'.format(dado),
            python_callable=salva_dados_db_raw,
            trigger_rule="dummy",
            dag=dag)

        transformacao = BranchPythonOperator(
            task_id='Transformacao_dados_{}'.format(dado),
            python_callable=transforma_dados,
            trigger_rule="one_success",
            dag=dag)

        salvar_base_staging = BranchPythonOperator(
            task_id='Salvar_DB_Staging_{}'.format(dado),
            python_callable=salva_dados_db_staging,
            trigger_rule="all_success",
            dag=dag)

        # define the flow
        previous_task >> extracao
        extracao >> email_erro
        extracao >> salvar_base_raw
        email_erro >> stop_falha
        stop_falha >> transformacao
        salvar_base_raw >> transformacao
        transformacao >> salvar_base_staging
        salvar_base_staging >> next_task

def test_clear_skipped_downstream_task(self):
    """
    After a downstream task is skipped by BranchPythonOperator, clearing the skipped task
    should not cause it to be executed.
    """
    branch_op = BranchPythonOperator(task_id='make_choice',
                                     dag=self.dag,
                                     python_callable=lambda: 'branch_1')
    branches = [self.branch_1, self.branch_2]
    branch_op >> branches
    self.dag.clear()

    dr = self.dag.create_dagrun(run_id="manual__",
                                start_date=timezone.utcnow(),
                                execution_date=DEFAULT_DATE,
                                state=State.RUNNING)

    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    for task in branches:
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    tis = dr.get_task_instances()
    for ti in tis:
        if ti.task_id == 'make_choice':
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_1':
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_2':
            self.assertEqual(ti.state, State.SKIPPED)
        else:
            raise Exception

    children_tis = [
        ti for ti in tis
        if ti.task_id in branch_op.get_direct_relative_ids()
    ]

    # Clear the children tasks.
    with create_session() as session:
        clear_task_instances(children_tis, session=session, dag=self.dag)

    # Run the cleared tasks again.
    for task in branches:
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Check if the states are correct after children tasks are cleared.
    for ti in dr.get_task_instances():
        if ti.task_id == 'make_choice':
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_1':
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_2':
            self.assertEqual(ti.state, State.SKIPPED)
        else:
            raise Exception

def __init__(self,
             args: Dict,
             parent_dag_id: str,
             child_dag_id: str,
             repository_class: Type[TaskRepositoryMixin],
             engine: Engine = None):
    """ Defines subDAG tasks """
    self._parent_dag_id = parent_dag_id
    self._child_dag_id = child_dag_id
    self._repository_class = repository_class
    self._engine = engine

    self._subdag = DAG(
        dag_id=f'{self._parent_dag_id}.{self._child_dag_id}',
        default_args=args,
        schedule_interval=None)

    self._initialize_task_operator = PythonOperator(
        task_id=f'initialize_{self._child_dag_id}',
        provide_context=True,
        python_callable=self._initialize_task,
        dag=self._subdag)

    self._conditional_operator = BranchPythonOperator(
        task_id=f'conditional_{self._child_dag_id}',
        provide_context=True,
        python_callable=self._execute_or_skip_task,
        dag=self._subdag)

    self._dummy_operator = DummyOperator(
        task_id=f'skip_{self._child_dag_id}',
        dag=self._subdag)

    self._start_task_in_db_operator = PythonOperator(
        task_id=f'start_task_in_db_{self._child_dag_id}',
        provide_context=True,
        python_callable=self._start_task,
        dag=self._subdag)

    self._parametrized_bash_operator = ParametrizedBashOperator(
        task_id=f'bash_{self._child_dag_id}',
        parameters_provider=self._parameters_provider,
        bash_command='echo',
        dag=self._subdag)

    self._finish_task_in_db_operator = PythonOperator(
        task_id=f'finish_task_in_db_{self._child_dag_id}',
        provide_context=True,
        python_callable=self._finish_task,
        dag=self._subdag)

    self._join_operator = DummyOperator(
        task_id=f'join_{self._child_dag_id}',
        trigger_rule='one_success',
        dag=self._subdag)

def setUp(self):
    self.dag = DAG('branch_operator_test',
                   default_args={
                       'owner': 'airflow',
                       'start_date': DEFAULT_DATE},
                   schedule_interval=INTERVAL)
    self.branch_op = BranchPythonOperator(task_id='make_choice',
                                          dag=self.dag,
                                          python_callable=lambda: 'branch_1')

    self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
    self.branch_1.set_upstream(self.branch_op)
    self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
    self.branch_2.set_upstream(self.branch_op)
    self.dag.clear()

def __init__(self, day, true_result_task, false_result_task, *args, **kwargs):
    BranchPythonOperator.__init__(self,
                                  task_id='check_month_day{}'.format(day),
                                  python_callable=check_month_day,
                                  provide_context=True,
                                  op_kwargs={
                                      'day': day,
                                      'true_result_task': true_result_task,
                                      'false_result_task': false_result_task
                                  },
                                  *args, **kwargs)

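# The `check_month_day` callable wired into the operator above is not included
# in this snippet. Below is a minimal sketch of what it plausibly looks like,
# inferred only from the op_kwargs names; the comparison against the execution
# date's day of month is an assumption, not the original implementation.
def check_month_day(day, true_result_task, false_result_task, **context):
    # provide_context=True makes the execution date available in the kwargs.
    execution_date = context['execution_date']
    # Follow the "true" branch when the run falls on the given day of the month.
    if execution_date.day == day:
        return true_result_task
    return false_result_task
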
def create_account_dags(self):
    accounts = self.__get_accounts()
    template_dag_id = "template_{}"
    for account_id in accounts:
        dag_id = template_dag_id.format(account_id)
        globals()[dag_id] = self.create_template_dag(dag_id)

        task_start = DummyOperator(task_id="start", dag=globals()[dag_id])
        task_preprocess = self.create_template_task(dag_id, account_id, 'preprocess')
        task_train_skip = DummyOperator(task_id="train_skip", dag=globals()[dag_id])
        task_train = self.create_template_task(dag_id, account_id, 'train')  # , 'cloud')
        task_check_model = BranchPythonOperator(
            task_id="check_model",
            python_callable=check_branch,
            op_kwargs={'account_id': account_id,
                       'task_code': 'check_model',
                       # DO NOT CALL THIS ARGUMENT task, IT IS RESERVED!
                       '1': "train",
                       '0': "train_skip",
                       'python_path': self.das_python_path,
                       'jobs_path': self.das_jobs_path,
                       'conf_code': DagBuilder.CONF_CODE},
            provide_context=True,
            dag=globals()[dag_id])
        task_predict = self.create_template_task(dag_id, account_id, 'predict')
        task_end = DummyOperator(task_id="end", dag=globals()[dag_id])
        task_verify_model = BranchPythonOperator(
            task_id="verify_model",
            python_callable=check_branch,
            op_kwargs={'account_id': account_id,
                       'task_code': 'verify_model',
                       '1': "end",
                       '0': "predict",
                       'python_path': self.das_python_path,
                       'jobs_path': self.das_jobs_path,
                       'conf_code': DagBuilder.CONF_CODE},
            provide_context=True,
            dag=globals()[dag_id],
            trigger_rule='none_failed')

        task_start >> task_preprocess >> task_check_model >> [
            task_train_skip, task_train
        ] >> task_verify_model >> [task_predict, task_end]
        task_predict >> task_end

def create_dag(dag_id, start_date, table, database, schedule_interval=None):
    with DAG(
        dag_id=dag_id,
        start_date=start_date,
        schedule_interval=schedule_interval
    ) as dag:
        start = PythonOperator(task_id='start',
                               python_callable=start_func(dag_id, database))
        who_am_i = BashOperator(task_id='who_am_i',
                                xcom_push=True,
                                bash_command='whoami')
        check_exist = BranchPythonOperator(
            task_id="check_table_exist",
            python_callable=check_table_exist(schema=database, table=table))
        create_table = PythonOperator(
            task_id='create_table',
            python_callable=create_table_func(schema=database, table=table))
        do_nothing = DummyOperator(task_id='do_nothing')
        insert_new_row = PythonOperator(
            task_id='insert_new_row',
            provide_context=True,
            trigger_rule='all_done',
            python_callable=insert_row(schema=database, table=table))
        query_the_table = PostgreSQLCountRows(task_id='query_the_table',
                                              schema=database,
                                              table=table)
        end = BashOperator(task_id='end',
                           xcom_push=True,
                           bash_command="echo '{{ run_id }} ended'")

        start >> who_am_i >> check_exist >> [create_table, do_nothing] >> insert_new_row >> query_the_table >> end

    return dag

def _get_subdag_test_dag(self):
    with DAG(dag_id='test_dag', default_args=DEFAULT_DAG_ARGS) as dag:
        def fn0():
            math.pow(1, 2)

        def fn1():
            print("hi")

        def fn2():
            math.factorial(1)

        op1 = BranchPythonOperator(task_id='op1', python_callable=range)
        op2 = PythonOperator(task_id='op2', python_callable=fn0)
        op3 = PythonOperator(task_id='op3', python_callable=fn1)
        op4 = PythonOperator(task_id='op4', python_callable=fn2)
        op5 = DummyOperator(task_id='op5')
        op6 = PythonOperator(task_id='op6', python_callable=fn0)
        op7 = PythonOperator(task_id='op7', python_callable=fn1)
        op8 = PythonOperator(task_id='op8', python_callable=fn2)
        op9 = DummyOperator(task_id='op9')

        op1 >> [op2, op3] >> op4
        op2 >> op5 >> [op6, op7] >> op8
        [op4, op6, op8] >> op9

    return dag

def _get_test_dag(self):
    with DAG(dag_id='test_dag', default_args=DEFAULT_DAG_ARGS) as dag:
        op1 = SparkSubmitOperator(task_id='op1')
        op2 = EmrAddStepsOperator(task_id='op2', job_flow_id='foo')
        op3 = S3ListOperator(task_id='op3', bucket='foo')
        op4 = EmrCreateJobFlowOperator(task_id='op4')
        op5 = TriggerDagRunOperator(task_id='op5', trigger_dag_id='foo')
        op6 = FileToWasbOperator(task_id='op6',
                                 container_name='foo',
                                 blob_name='foo',
                                 file_path='foo')
        op7 = EmailOperator(task_id='op7',
                            subject='foo',
                            to='foo',
                            html_content='foo')
        op8 = S3CopyObjectOperator(task_id='op8',
                                   dest_bucket_key='foo',
                                   source_bucket_key='foo')
        op9 = BranchPythonOperator(task_id='op9', python_callable=print)
        op10 = PythonOperator(task_id='op10', python_callable=range)

        op1 >> [op2, op3, op4]
        op2 >> [op5, op6]
        op6 >> [op7, op8, op9]
        op3 >> [op7, op8]
        op8 >> [op9, op10]

    return dag

def get_decide_airflow_upgrade(self, task_id=dn.DECIDE_AIRFLOW_UPGRADE):
    """Generate the decide_airflow_upgrade step

    Step responsible for deciding whether to branch to the path
    to upgrade airflow worker
    """
    def upgrade_airflow_check(**kwargs):
        """upgrade_airflow_check function

        Defines a function to decide whether to upgrade airflow
        worker. The decision will be based on the xcom value that
        is retrieved from the 'armada_post_apply' task
        """
        # DAG ID will be parent + subdag name
        dag_id = self.parent_dag_name + '.' + dn.ARMADA_BUILD_DAG_NAME

        # Check if Shipyard/Airflow were upgraded by the workflow
        upgrade_airflow = kwargs['ti'].xcom_pull(
            key='upgrade_airflow_worker',
            task_ids='armada_post_apply',
            dag_id=dag_id)

        # Go to the branch to upgrade Airflow worker if the Shipyard
        # chart were upgraded/modified
        if upgrade_airflow == "true":
            return "upgrade_airflow"
        else:
            return "skip_upgrade_airflow"

    return BranchPythonOperator(task_id=task_id,
                                python_callable=upgrade_airflow_check,
                                trigger_rule="all_success",
                                dag=self.dag)

def create_dag():
    dag = DAG(DAG_ID,
              default_args=default_args,
              schedule_interval='@hourly',
              catchup=False)
    with dag:
        finish_task = DummyOperator(task_id='finish')
        pusher_task_id = 'schedule_df_wrench_to_lake'
        should_run_task = BranchPythonOperator(task_id='should_run',
                                               python_callable=should_run)
        schedule_df_task = ScheduleDataflowJobOperator(
            task_id=pusher_task_id,
            project=project_id,
            template_name='load_wrench_to_lake',
            job_name='wrench-to-lake',
            job_parameters={},
            provide_context=True)
        monitor_df_job_task = DataflowJobStateSensor(
            task_id='monitor_df_job',
            pusher_task_id=pusher_task_id,
            poke_interval=airflow_vars['dags']['wrench_to_lake']['poke_interval'],
            timeout=airflow_vars['dags']['wrench_to_lake']['poke_timeout'],
            dag=dag)
        move_files_task = PythonOperator(task_id='move_processed_files',
                                         python_callable=move_files)

        should_run_task >> schedule_df_task >> monitor_df_job_task >> move_files_task >> finish_task

    return dag

def MakeMarkComplete(dag):
    """Make the final sequence of the daily graph."""
    mark_complete = BranchPythonOperator(
        task_id='mark_complete',
        python_callable=ReportDailySuccessful,
        provide_context=True,
        dag=dag,
    )

    tag_daily_grc = BashOperator(
        task_id='tag_daily_gcr',
        bash_command=gcr_tag_success,
        dag=dag,
    )

    # skip_grc = DummyOperator(
    #     task_id='skip_tag_daily_gcr',
    #     dag=dag,
    # )

    # end = DummyOperator(
    #     task_id='end',
    #     dag=dag,
    #     trigger_rule="one_success",
    # )

    mark_complete >> tag_daily_grc
    # mark_complete >> skip_grc >> end
    return mark_complete

def test_parent_skip_branch():
    """
    A simple DAG with a BranchPythonOperator that does not follow op2.
    NotPreviouslySkippedDep is not met.
    """
    with create_session() as session:
        session.query(DagRun).delete()
        session.query(TaskInstance).delete()

        start_date = pendulum.datetime(2020, 1, 1)
        dag = DAG("test_parent_skip_branch_dag",
                  schedule_interval=None,
                  start_date=start_date)
        dag.create_dagrun(run_type=DagRunType.MANUAL,
                          state=State.RUNNING,
                          execution_date=start_date)
        op1 = BranchPythonOperator(task_id="op1",
                                   python_callable=lambda: "op3",
                                   dag=dag)
        op2 = DummyOperator(task_id="op2", dag=dag)
        op3 = DummyOperator(task_id="op3", dag=dag)
        op1 >> [op2, op3]

        TaskInstance(op1, start_date).run()
        ti2 = TaskInstance(op2, start_date)
        dep = NotPreviouslySkippedDep()

        assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 1
        session.commit()
        assert not dep.is_met(ti2, session)
        assert ti2.state == State.SKIPPED

def dag_preprocess_tables(
        dag_id,
        schedule_interval,
        start_date,
        target_project_id,
        target_dataset_id,
        table_config,
        table_partition,
):
    dag = DAG(dag_id=dag_id,
              schedule_interval=schedule_interval,
              start_date=start_date)

    for table in table_config:
        start_check_tables_task = DummyOperator(
            task_id='%s-%s' % ("start_check_tables_task", table["name"]),
            dag=dag)

        check_if_table_exist = BranchPythonOperator(
            task_id='%s-%s' % (table["name"], "check_if_table_exist"),
            python_callable=if_tbl_exists,
            op_kwargs={
                'dataset': target_dataset_id,
                'project': target_project_id,
                'table_name': table["name"]
            },
            dag=dag)

        table_exists = DummyOperator(
            task_id='%s-%s' % (table["name"], "table_exists"),
            dag=dag)

        table_does_not_exist = DummyOperator(
            task_id='%s-%s' % (table["name"], "table_does_not_exist"),
            dag=dag)

        # [start create equipped_item_reference if not exists]
        create_if_not_exists = BigQueryCreateEmptyTableOperator(
            task_id='%s-%s' % (table["name"], "create_if_not_exists"),
            project_id=target_project_id,
            dataset_id=target_dataset_id,
            table_id=table["name"],
            gcs_schema_object=table["schema_gcs_location"],
            time_partitioning=table_partition,
            trigger_rule=TriggerRule.ALL_SUCCESS,
            dag=dag)

        end_check_tables_task = DummyOperator(
            task_id='%s-%s' % ("end_check_tables_task", table["name"]),
            trigger_rule='none_failed_or_skipped',
            dag=dag)

        start_check_tables_task >> check_if_table_exist >> [
            table_does_not_exist, table_exists
        ]
        table_does_not_exist >> create_if_not_exists >> end_check_tables_task
        table_exists >> end_check_tables_task

    return dag

def my_sub_dag(parent_dag_id):
    # Step 1 - define the default parameters for the DAG
    default_args = {
        'owner': 'Nitin Ware',
        'depends_on_past': False,
        'start_date': datetime(2019, 7, 28),
        'email': ['*****@*****.**'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': timedelta(minutes=5),
    }

    p_val = Variable.get('v_val')  # Variable passed to registered method

    # Step 2 - Create a DAG object
    my_sub_dag = DAG(dag_id=parent_dag_id + '.' + 'my_sub_dag',
                     schedule_interval='0 0 * * *',
                     default_args=default_args)

    # Step 3 - Define the method to check the condition for branching
    def my_check_condition(**kwargs):
        if int(p_val) >= 15:
            return 'greater_Than_equal_to_15'
        else:
            return 'less_Than_15'

    # Step 4 - Create a Branching task to Register the method in step 3 to the branching API
    checkTask = BranchPythonOperator(
        task_id='check_task',
        python_callable=my_check_condition,  # Registered method
        provide_context=True,
        dag=my_sub_dag)

    # Step 5 - Create tasks
    greaterThan15 = BashOperator(
        task_id='greater_Than_equal_to_15',
        bash_command="echo value is greater than or equal to 15",
        dag=my_sub_dag)

    lessThan15 = BashOperator(task_id='less_Than_15',
                              bash_command="echo value is less than 15",
                              dag=my_sub_dag)

    finalTask = BashOperator(task_id='join_task',
                             bash_command="echo This is a join",
                             trigger_rule=TriggerRule.ONE_SUCCESS,
                             dag=my_sub_dag)

    # Step 6 - Define the sequence of tasks.
    lessThan15.set_upstream(checkTask)
    greaterThan15.set_upstream(checkTask)
    finalTask.set_upstream([lessThan15, greaterThan15])

    # Step 7 - Return the DAG
    return my_sub_dag

def apply_task_to_dag(self):
    check_dags_queued_task = BranchPythonOperator(
        task_id=f'{self.task_id}-is-dag-queue-empty',
        python_callable=self.__queued_dag_runs_exists,
        provide_context=True,
        trigger_rule=TriggerRule.ALL_DONE,
        dag=self.dag)

    delete_stack_task = CloudFormationDeleteStackOperator(
        task_id=f'delete-cloudformation-{self.task_id}',
        params={'StackName': self.stack_name},
        dag=self.dag)

    delete_stack_sensor = CloudFormationDeleteStackSensor(
        task_id=f'cloudformation-watch-{self.task_id}-delete',
        stack_name=self.stack_name,
        dag=self.dag)

    stack_delete_end_task = DummyOperator(
        task_id=f'delete-end-{self.task_id}',
        dag=self.dag)

    if self.parent:
        self.parent.set_downstream(check_dags_queued_task)

    check_dags_queued_task.set_downstream(stack_delete_end_task)
    check_dags_queued_task.set_downstream(delete_stack_task)
    delete_stack_task.set_downstream(delete_stack_sensor)
    delete_stack_sensor.set_downstream(stack_delete_end_task)

    return stack_delete_end_task

def apply_task_to_dag(self):
    check_cloudformation_stack_exists_task = BranchPythonOperator(
        templates_dict={'stack_name': self.stack_name},
        task_id=f'is-cloudformation-{self.task_id}-running',
        python_callable=self.__cloudformation_stack_running_branch,
        provide_context=True,
        dag=self.dag)

    create_cloudformation_stack_task = CloudFormationCreateStackOperator(
        task_id=f'create-cloudformation-{self.task_id}',
        params={**self.__reformatted_params()},
        dag=self.dag)

    create_stack_sensor_task = CloudFormationCreateStackSensor(
        task_id=f'cloudformation-watch-{self.task_id}-create',
        stack_name=self.stack_name,
        dag=self.dag)

    stack_creation_end_task = DummyOperator(
        task_id=f'creation-end-{self.task_id}',
        dag=self.dag,
        trigger_rule='all_done')

    if self.parent:
        self.parent.set_downstream(check_cloudformation_stack_exists_task)

    create_stack_sensor_task.set_downstream(stack_creation_end_task)
    create_cloudformation_stack_task.set_downstream(create_stack_sensor_task)
    check_cloudformation_stack_exists_task.set_downstream(create_cloudformation_stack_task)
    check_cloudformation_stack_exists_task.set_downstream(stack_creation_end_task)

    return stack_creation_end_task

def cell_image_analysis_generate_decide_run_cellprofiler(dag: DAG) -> List[BranchPythonOperator]:
    tasks = []
    for well in plate_wells_384()[50:55]:
        tasks.append(
            BranchPythonOperator(
                dag=dag,
                task_id='cell_image_analysis_decide_run_cellprofiler_{}'.format(well),
                params={'well': well},
                provide_context=True,
                python_callable=cell_image_analysis_decide_run_cellprofiler))
    return tasks

def create_dag():
    dag = DAG(
        DAG_ID,
        default_args=default_args,
        # Be sure to stagger the dags so they don't run all at once,
        # possibly causing max memory usage and pod failure. - Stu M.
        schedule_interval='0 * * * *',
        catchup=False)
    with dag:
        start_task = DummyOperator(task_id='start')
        finish_task = DummyOperator(task_id='finish')

        storage = CloudStorage.factory(project_id)
        cdc_imports_bucket = storage.get_bucket(bucket)
        cdc_imports_processed_bucket = storage.get_bucket(processed_bucket)

        for files_startwith, table in table_map.items():
            pusher_task_id = f'schedule_df_gcs_to_lake_{table}'
            continue_if_file_task = BranchPythonOperator(
                task_id=f'continue_if_file_{files_startwith}',
                python_callable=should_continue,
                op_args=[files_startwith, cdc_imports_bucket, table])
            schedule_df_task = ScheduleDataflowJobOperator(
                task_id=pusher_task_id,
                project=project_id,
                template_name='load_cdc_from_gcs_to_lake',
                job_name=f'gcs-to-lake-{table}',
                job_parameters={
                    'files_startwith': files_startwith,
                    'dest': f'{project_id}:lake.{table}'
                },
                provide_context=True)
            monitor_df_job_task = DataflowJobStateSensor(
                task_id=f'monitor_df_job_{table}',
                pusher_task_id=pusher_task_id,
                poke_interval=airflow_vars['dags']['cdc_from_gcs_to_lake']['poke_interval'],
                timeout=airflow_vars['dags']['cdc_from_gcs_to_lake']['poke_timeout'],
                dag=dag)
            move_files_task = PythonOperator(
                task_id=f'move_processed_files_{files_startwith}',
                python_callable=storage.move_files,
                op_args=[
                    files_startwith,
                    cdc_imports_bucket,
                    cdc_imports_processed_bucket
                ],
            )

            (start_task >> continue_if_file_task >> schedule_df_task >>
             monitor_df_job_task >> move_files_task >> finish_task)

    return dag

def test_xcom_push(self):
    self.branch_op = BranchPythonOperator(task_id='make_choice',
                                          dag=self.dag,
                                          python_callable=lambda: 'branch_1')
    self.branch_1.set_upstream(self.branch_op)
    self.branch_2.set_upstream(self.branch_op)
    self.dag.clear()

    dr = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING
    )

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    tis = dr.get_task_instances()
    for ti in tis:
        if ti.task_id == 'make_choice':
            self.assertEqual(ti.xcom_pull(task_ids='make_choice'), 'branch_1')

def Decide(config) -> BaseOperator:
    def execute(config, **kwargs):
        # ===================================
        # Place your branching strategy here. Only the task ids that this function returns
        # will be executed.
        # It can return a single task_id as a string (e.g. `return "next_task_id"`) or
        # a list with many task ids (e.g. `return ["task_id_1", "task_id_2"]`).
        # ===================================
        return "Run_Workflow"

    return BranchPythonOperator(task_id="Decide",
                                op_kwargs={"config": config},
                                python_callable=execute,
                                trigger_rule="all_success")

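# A minimal usage sketch for the Decide factory above. The operator is created
# without a dag argument, so it has to be instantiated inside a `with DAG(...)`
# block. The dag id, schedule, and the 'Skip_Workflow' task are assumptions;
# 'Run_Workflow' is the task id the branch callable actually returns.
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

with DAG(dag_id='decide_example',
         start_date=datetime(2021, 1, 1),
         schedule_interval=None) as dag:
    decide = Decide(config={})
    run_workflow = DummyOperator(task_id='Run_Workflow')
    skip_workflow = DummyOperator(task_id='Skip_Workflow')
    # Only the branch whose task id the callable returns will run;
    # the other downstream task is skipped.
    decide >> [run_workflow, skip_workflow]
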
def MakeMarkComplete(dag):
    """Make the final sequence of the daily graph."""
    mark_complete = BranchPythonOperator(
        task_id='mark_complete',
        python_callable=ReportDailySuccessful,
        provide_context=True,
        dag=dag,
    )

    gcr_tag_success = r"""
{% set settings = task_instance.xcom_pull(task_ids='generate_workflow_args') %}
set -x
pwd; ls
gsutil ls gs://{{ settings.GCS_FULL_STAGING_PATH }}/docker/ > docker_tars.txt
cat docker_tars.txt | grep -Eo "docker\/(([a-z]|-)*).tar.gz" | \
    sed -E "s/docker\/(([a-z]|-)*).tar.gz/\1/g" > docker_images.txt

gcloud auth configure-docker -q
cat docker_images.txt | \
while read -r docker_image; do
    pull_source="gcr.io/{{ settings.GCR_STAGING_DEST }}/${docker_image}:{{ settings.VERSION }}"
    push_dest="gcr.io/{{ settings.GCR_STAGING_DEST }}/${docker_image}:latest_{{ settings.BRANCH }}"
    docker pull $pull_source
    docker tag $pull_source $push_dest
    docker push $push_dest
done

cat docker_tars.txt docker_images.txt
rm docker_tars.txt docker_images.txt
"""

    tag_daily_grc = BashOperator(
        task_id='tag_daily_gcr',
        bash_command=gcr_tag_success,
        dag=dag,
    )

    # skip_grc = DummyOperator(
    #     task_id='skip_tag_daily_gcr',
    #     dag=dag,
    # )

    # end = DummyOperator(
    #     task_id='end',
    #     dag=dag,
    #     trigger_rule="one_success",
    # )

    mark_complete >> tag_daily_grc
    # mark_complete >> skip_grc >> end
    return mark_complete

def run_huawei_2g_parser(parent_dag_name, child_dag_name, start_date, schedule_interval):
    """
    Parse huawei 2g cm files.

    :param parent_dag_name:
    :param child_dag_name:
    :param start_date:
    :param schedule_interval:
    :return:
    """
    dag = DAG(
        '%s.%s' % (parent_dag_name, child_dag_name),
        schedule_interval=schedule_interval,
        start_date=start_date,
    )

    def get_cm_file_format():
        # if 'huawei_mml'
        return 'run_huawei_2g_mml_parser'

    t23 = BranchPythonOperator(
        task_id='branch_huawei_2g_parser',
        python_callable=get_cm_file_format,
        dag=dag)

    t29 = BashOperator(
        task_id='run_huawei_2g_xml_nbi_parser',
        bash_command='java -jar /mediation/bin/boda-huaweinbixmlparser.jar /mediation/data/cm/huawei/2g/raw/in /mediation/data/cm/huawei/2g/parsed/in /mediation/conf/cm/hua_cm_2g_nbi_parameters.cfg',
        dag=dag)

    t29_2 = BashOperator(
        task_id='run_huawei_2g_mml_parser',
        bash_command='java -jar /mediation/bin/boda-huaweimmlparser.jar /mediation/data/cm/huawei/2g/raw/in /mediation/data/cm/huawei/2g/parsed/in /mediation/conf/cm/hua_cm_2g_mml_parser.cfg',
        dag=dag)

    t_join = DummyOperator(
        task_id='join_huawei_2g_parser',
        dag=dag,
    )

    dag.set_dependency('branch_huawei_2g_parser', 'run_huawei_2g_mml_parser')
    dag.set_dependency('branch_huawei_2g_parser', 'run_huawei_2g_xml_nbi_parser')

    dag.set_dependency('run_huawei_2g_mml_parser', 'join_huawei_2g_parser')
    dag.set_dependency('run_huawei_2g_xml_nbi_parser', 'join_huawei_2g_parser')

    return dag

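# The branch callable above is a stub: its comment hints at format detection
# but it always returns the MML parser task. Below is a hedged sketch of how
# it might be completed; the two returned task ids and the raw input path come
# from the snippet, while the extension-based format check is an assumption.
import os

def get_cm_file_format():
    raw_dir = '/mediation/data/cm/huawei/2g/raw/in'
    # Treat the dump as an MML export if any raw file carries an .mml
    # extension; this check is a placeholder for real format detection.
    for name in os.listdir(raw_dir):
        if name.lower().endswith('.mml'):
            return 'run_huawei_2g_mml_parser'
    return 'run_huawei_2g_xml_nbi_parser'
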
def Branch_0(config) -> BaseOperator:
    def execute(config, **kwargs):
        # ===================================
        # Place your branching strategy here. Only the task ids that this function returns
        # will be executed.
        # It can return a single task_id as a string (e.g. `return "next_task_id"`) or
        # a list with many task ids (e.g. `return ["task_id_1", "task_id_2"]`).
        # ===================================
        return "Workflow_1"

    return BranchPythonOperator(task_id="Branch_0",
                                op_kwargs={"config": config},
                                python_callable=execute,
                                trigger_rule="all_success")

class BranchOperatorTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super(BranchOperatorTest, cls).setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG('branch_operator_test',
                       default_args={
                           'owner': 'airflow',
                           'start_date': DEFAULT_DATE},
                       schedule_interval=INTERVAL)

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_without_dag_run(self):
        """This checks the defensive handling of non-existent tasks in a dag run"""
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )

            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'branch_1':
                    # should exist with state None
                    self.assertEqual(ti.state, State.NONE)
                elif ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.SKIPPED)
                else:
                    raise Exception

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: ['branch_1', 'branch_2'])
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )

            expected = {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            }

            for ti in tis:
                if ti.task_id in expected:
                    self.assertEqual(ti.state, expected[ti.task_id])
                else:
                    raise Exception

    def test_with_dag_run(self):
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise Exception

    def test_with_skip_in_branch_downstream_dependencies(self):
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.NONE)
            else:
                raise Exception

    def test_with_skip_in_branch_downstream_dependencies2(self):
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_2')
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                self.assertEqual(ti.state, State.SKIPPED)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.NONE)
            else:
                raise Exception

    dag=dag)

t3 = PythonOperator(
    task_id='compare_result',
    provide_context=True,
    python_callable=compare_result,
    trigger_rule="all_done",
    dag=dag)

t3.set_upstream(t1)
t3.set_upstream(t2)

options = ['hadoop_jar_cmd', 'presto_cmd', 'db_query', 'spark_cmd']

branching = BranchPythonOperator(
    task_id='branching',
    python_callable=lambda: random.choice(options),
    dag=dag)
branching.set_upstream(t3)

join = DummyOperator(
    task_id='join',
    trigger_rule='one_success',
    dag=dag
)

t4 = QuboleOperator(
    task_id='hadoop_jar_cmd',
    command_type='hadoopcmd',
    sub_command='jar s3://paid-qubole/HadoopAPIExamples/jars/hadoop-0.20.1-dev-streaming.jar -mapper wc -numReduceTasks 0 -input s3://paid-qubole/HadoopAPITests/data/3.tsv -output s3://paid-qubole/HadoopAPITests/data/3_wc',
    cluster_label='default',
    fetch_logs=True,

    'owner': 'airflow',
    'start_date': seven_days_ago,
}

dag = DAG(
    dag_id='example_branch_operator',
    default_args=args,
    schedule_interval="@daily")

cmd = 'ls -l'
run_this_first = DummyOperator(task_id='run_this_first', dag=dag)

options = ['branch_a', 'branch_b', 'branch_c', 'branch_d']

branching = BranchPythonOperator(
    task_id='branching',
    python_callable=lambda: random.choice(options),
    dag=dag)
branching.set_upstream(run_this_first)

join = DummyOperator(
    task_id='join',
    trigger_rule='one_success',
    dag=dag
)

for option in options:
    t = DummyOperator(task_id=option, dag=dag)
    t.set_upstream(branching)
    dummy_follow = DummyOperator(task_id='follow_' + option, dag=dag)
    t.set_downstream(dummy_follow)
    dummy_follow.set_downstream(join)