Beispiel #1
0
    def test_dagrun_success_when_all_skipped(self):
        """
        Verify that a running DAG run ends up successful when its remaining tasks are skipped.

        A short-circuit task whose callable returns ``False`` heads a chain of two
        dummy tasks. The run is created in ``RUNNING`` state with the short-circuit
        task already ``SUCCESS`` and both dummy tasks ``SKIPPED``; after
        ``update_state()`` the run itself must be ``SUCCESS``.
        """
        dag = DAG(
            dag_id='test_dagrun_success_when_all_skipped',
            start_date=timezone.datetime(2017, 1, 1),
        )
        short_circuit = ShortCircuitOperator(
            task_id='test_short_circuit_false',
            dag=dag,
            python_callable=lambda: False,
        )
        skipped_first = DummyOperator(task_id='test_state_skipped1', dag=dag)
        skipped_second = DummyOperator(task_id='test_state_skipped2', dag=dag)
        # Linear chain: short circuit -> skipped1 -> skipped2.
        short_circuit >> skipped_first >> skipped_second

        # States the task instances are given when the run is created.
        preset_states = {'test_short_circuit_false': State.SUCCESS}
        for skipped_id in ('test_state_skipped1', 'test_state_skipped2'):
            preset_states[skipped_id] = State.SKIPPED

        run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=preset_states)
        run.update_state()
        self.assertEqual(State.SUCCESS, run.state)
Beispiel #2
0
    def test_clear_skipped_downstream_task(self):
        """
        Check that a task skipped by ShortCircuitOperator stays skipped after it is cleared.

        The short-circuit callable returns ``False``, so ``downstream`` is expected
        to be ``SKIPPED`` once both tasks have run. The ``downstream`` task instance
        is then cleared and the task is run again; the states of both task
        instances must be the same as before the clear.
        """
        dag = DAG(
            'shortcircuit_clear_skipped_downstream_task',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: False)
        downstream = DummyOperator(task_id='downstream', dag=dag)
        short_op.set_downstream(downstream)

        dag.clear()

        dr = dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        # Every run below covers the single DEFAULT_DATE interval.
        run_window = {'start_date': DEFAULT_DATE, 'end_date': DEFAULT_DATE}
        expected_states = {'make_choice': State.SUCCESS, 'downstream': State.SKIPPED}

        def check_states(task_instances):
            """Assert each task instance has its expected state; reject unknown task ids."""
            for ti in task_instances:
                if ti.task_id not in expected_states:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
                assert ti.state == expected_states[ti.task_id]

        short_op.run(**run_window)
        downstream.run(**run_window)

        tis = dr.get_task_instances()
        check_states(tis)

        # Clear only the downstream task instance.
        downstream_tis = [ti for ti in tis if ti.task_id == "downstream"]
        with create_session() as session:
            clear_task_instances(downstream_tis, session=session, dag=dag)

        # Running downstream again must not change the outcome.
        downstream.run(**run_window)
        check_states(dr.get_task_instances())
Beispiel #3
0
    def test_get_task_instance_on_empty_dagrun(self):
        """
        Check that ``get_task_instance`` returns ``None`` for a dagrun without task instances.
        """
        dag = DAG(
            dag_id='test_get_task_instance_on_empty_dagrun',
            start_date=timezone.datetime(2017, 1, 1),
        )
        ShortCircuitOperator(
            task_id='test_short_circuit_false',
            dag=dag,
            python_callable=lambda: False,
        )

        run_time = timezone.utcnow()

        # The DagRun row is built by hand instead of through create_dagrun, because
        # create_dagrun would create the task instances too, which is not wanted here.
        empty_run = models.DagRun(
            dag_id=dag.dag_id,
            run_type=DagRunType.MANUAL,
            execution_date=run_time,
            start_date=run_time,
            state=State.RUNNING,
            external_trigger=False,
        )

        db_session = settings.Session()
        db_session.add(empty_run)
        db_session.commit()

        assert empty_run.get_task_instance('test_short_circuit_false') is None
Beispiel #4
0
    def test_with_dag_run(self):
        """
        Exercise ShortCircuitOperator inside a DAG run for both results of its callable.

        Task layout: ``upstream -> make_choice -> branch_1 -> branch_2``. While the
        callable returns ``False`` both branches must be ``SKIPPED``. After the DAG
        is cleared and the callable returns ``True`` the branches must be in
        ``State.NONE``. In both passes ``upstream`` and ``make_choice`` must be
        ``SUCCESS`` and the run must hold exactly four task instances.
        """
        # Read by the lambda below at call time, so reassigning it flips the condition.
        value = False
        dag = DAG(
            'shortcircuit_operator_test_with_dag_run',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
        upstream = DummyOperator(task_id='upstream', dag=dag)
        upstream >> short_op >> branch_1 >> branch_2
        dag.clear()

        logging.error("Tasks %s", dag.tasks)
        dr = dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        run_window = {'start_date': DEFAULT_DATE, 'end_date': DEFAULT_DATE}

        def assert_states(branch_state):
            """Check the run has four task instances and each one is in its expected state."""
            expected = {
                'make_choice': State.SUCCESS,
                'upstream': State.SUCCESS,
                'branch_1': branch_state,
                'branch_2': branch_state,
            }
            tis = dr.get_task_instances()
            self.assertEqual(len(tis), 4)
            for ti in tis:
                if ti.task_id not in expected:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
                self.assertEqual(ti.state, expected[ti.task_id])

        # First pass: condition is False.
        upstream.run(**run_window)
        short_op.run(**run_window)
        assert_states(State.SKIPPED)

        # Second pass: condition is True, starting again from a cleared DAG.
        value = True
        dag.clear()
        dr.verify_integrity()
        upstream.run(**run_window)
        short_op.run(**run_window)
        assert_states(State.NONE)
Beispiel #5
0
    def test_without_dag_run(self):
        """
        Check the operator's defensive handling of tasks that do not exist in a dag run.

        Only ``make_choice`` is run (no dag run is created and ``upstream`` is never
        run). With the callable returning ``False`` the branches must be
        ``SKIPPED``; after clearing the DAG and switching the callable to ``True``
        they must be in ``State.NONE``. A task instance for ``upstream`` or for any
        unknown task id is an error.
        """
        # Read by the lambda below at call time, so reassigning it flips the condition.
        value = False
        dag = DAG(
            'shortcircuit_operator_test_without_dag_run',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
        upstream = DummyOperator(task_id='upstream', dag=dag)
        upstream >> short_op >> branch_1 >> branch_2
        dag.clear()

        run_window = {'start_date': DEFAULT_DATE, 'end_date': DEFAULT_DATE}

        def assert_states(task_instances, branch_state):
            """Assert the state of every task instance; 'upstream' is rejected like an unknown id."""
            expected = {
                'make_choice': State.SUCCESS,
                'branch_1': branch_state,
                'branch_2': branch_state,
            }
            for ti in task_instances:
                if ti.task_id not in expected:
                    # covers 'upstream' (should not exist) as well as unknown ids
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
                assert ti.state == expected[ti.task_id]

        short_op.run(**run_window)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == dag.dag_id,
                TI.execution_date == DEFAULT_DATE,
            )
            assert_states(tis, State.SKIPPED)

            value = True
            dag.clear()

            short_op.run(**run_window)
            assert_states(tis, State.NONE)
Beispiel #6
0
from airflow.utils import dates

# Passed to the DAG below as its default_args.
args = {'owner': 'airflow'}

dag = DAG(
    dag_id='example_short_circuit_operator',
    default_args=args,
    start_date=dates.days_ago(2),
    tags=['example'],
)

# Two short-circuit tasks: one whose callable always returns True, one always False.
cond_true = ShortCircuitOperator(task_id='condition_is_True', python_callable=lambda: True, dag=dag)
cond_false = ShortCircuitOperator(task_id='condition_is_False', python_callable=lambda: False, dag=dag)

# Two dummy tasks per condition: true_1/true_2 and false_1/false_2.
ds_true = [DummyOperator(task_id=f'true_{i}', dag=dag) for i in (1, 2)]
ds_false = [DummyOperator(task_id=f'false_{i}', dag=dag) for i in (1, 2)]

# Each condition is chained in front of its own pair of dummy tasks.
chain(cond_true, *ds_true)
chain(cond_false, *ds_false)
# Shell snippet for the view_file task: sleep 10 seconds, then print the base name
# of the first regular file that find reports under the templated
# params.source_location.
locate_file_cmd = """
    sleep 10
    find {{params.source_location}} -type f  -printf "%f\n" | head -1
"""

# Runs the snippet above with do_xcom_push enabled.
t_view = BashOperator(
    task_id="view_file",
    dag=dag,
    bash_command=locate_file_cmd,
    params={"source_location": "/your/input_dir/path"},
    do_xcom_push=True,
)

# Short-circuit gate: the callable is True unless view_file's output is the empty string.
t_is_data_available = ShortCircuitOperator(
    task_id="check_if_data_available",
    dag=dag,
    python_callable=lambda task_output: task_output != "",
    op_kwargs={"task_output": t_view.output},
)

t_move = DockerOperator(
    api_version="1.19",
    docker_url="tcp://localhost:2375",  # replace it with swarm/docker endpoint
    image="centos:latest",
    network_mode="bridge",
    mounts=[
        Mount(source="/your/host/input_dir/path",
              target="/your/input_dir/path",
              type="bind"),
        Mount(source="/your/host/output_dir/path",
              target="/your/output_dir/path",
              type="bind"),
Beispiel #8
0
        task_id=f'export_missing_premsa_{gene}',
        python_callable=export_sequences,
        op_kwargs={ "gene" : gene, "output_fn" : filepath },
        dag=dag,
    )

    # Bash task rendering the PREMSA command for this gene.
    # NOTE(review): ``i`` is presumably an integer index from the enclosing loop,
    # so ``i % 8`` would give a node value in 0-7 -- confirm against the loop header.
    pre_msa = BashOperator(
        task_id=f'pre_msa_{gene}',
        bash_command=PREMSA,
        params={
            'regions': regions,
            'filepath': filepath,
            'gene': gene,
            'node': i % 8,
            'stdout': stdout,
        },
        dag=dag,
    )

    # Short-circuit gate whose condition is is_export_populated(filepath=...).
    populated_check_task = ShortCircuitOperator(
        task_id=f'check_if_populated_{gene}',
        python_callable=is_export_populated,
        op_kwargs={'filepath': filepath},
        dag=dag,
    )

    # Store nuc_input, prot_input, type
    import_premsa_seqs = PythonOperator(
        task_id=f'store_premsa_{gene}',
        python_callable=store_premsa_file,
        op_kwargs={
            "nuc_input": nuc_input_filepath,
            "prot_input": prot_input_filepath,
            "gene": gene,
        },
        dag=dag,
    )

    mark_troubled_task = PythonOperator(
        task_id=f'mark_troubled_{gene}',
        python_callable=mark_troubled,
        op_kwargs={ "log_file" : stdout, "gene": gene },
Beispiel #9
0
from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import ShortCircuitOperator
from airflow.utils.trigger_rule import TriggerRule

with DAG(
        dag_id='example_short_circuit_operator',
        start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
        catchup=False,
        tags=['example'],
) as dag:
    # [START howto_operator_short_circuit]
    # One short-circuit task whose callable always returns True, one always False.
    cond_true = ShortCircuitOperator(task_id='condition_is_True', python_callable=lambda: True)
    cond_false = ShortCircuitOperator(task_id='condition_is_False', python_callable=lambda: False)

    # Two empty tasks per condition: true_1/true_2 and false_1/false_2.
    ds_true = [EmptyOperator(task_id=f'true_{i}') for i in (1, 2)]
    ds_false = [EmptyOperator(task_id=f'false_{i}') for i in (1, 2)]

    # Each condition is chained in front of its own pair of empty tasks.
    chain(cond_true, *ds_true)
    chain(cond_false, *ds_false)
    # [END howto_operator_short_circuit]

    # [START howto_operator_short_circuit_trigger_rules]
Beispiel #10
0
    bash_command=locate_file_cmd,
    do_xcom_push=True,
    params={"source_location": "/your/input_dir/path"},
    dag=dag,
)


def is_data_available(*args, **kwargs):
    """Return True if the view_file task left a non-empty value in XCom, False otherwise.

    :param args: accepted for compatibility with the caller; not used.
    :param kwargs: must contain ``ti``, an object whose ``xcom_pull`` method is
        called with ``key=None`` and ``task_ids="view_file"``.
    :return: ``False`` when the pulled value is ``None`` or the empty string,
        ``True`` for any other value.
    :raises KeyError: if ``ti`` is not present in ``kwargs``.
    """
    ti = kwargs["ti"]
    data = ti.xcom_pull(key=None, task_ids="view_file")
    # A pull that finds nothing yields None rather than "". Comparing only against
    # "" would report missing data as available, so both cases are rejected.
    return data is not None and data != ""


# Short-circuit gate whose condition is computed by is_data_available.
t_is_data_available = ShortCircuitOperator(
    task_id="check_if_data_available",
    python_callable=is_data_available,
    dag=dag,
)

t_move = DockerOperator(
    api_version="1.19",
    docker_url="tcp://localhost:2375",  # replace it with swarm/docker endpoint
    image="centos:latest",
    network_mode="bridge",
    volumes=[
        "/your/host/input_dir/path:/your/input_dir/path",
        "/your/host/output_dir/path:/your/output_dir/path",
    ],
    command=[
        "/bin/bash",
        "-c",
        "/bin/sleep 30; "