Пример #1
0
    def test_init_with_template_cluster_label(self):
        """A templated ``cluster_label`` is rendered from the operator's params."""
        operator = QuboleOperator(
            task_id=TASK_ID,
            dag=DAG(DAG_ID, start_date=DEFAULT_DATE),
            cluster_label='{{ params.cluster_label }}',
            params={'cluster_label': 'default'},
        )

        # Rendering through a TaskInstance resolves the Jinja template in place.
        TaskInstance(operator, DEFAULT_DATE).render_templates()

        self.assertEqual(operator.cluster_label, 'default')
Пример #2
0
def test_get_dated_main_runner_handles_day_shift():
    """With ``day_shift=1`` the main callable receives the day before the execution date."""
    run_day = datetime.strptime('2019-01-01', '%Y-%m-%d')
    dag = DAG(dag_id='test_dag', start_date=run_day)
    main_func = PickleMock()
    runner = op_util.get_dated_main_runner_operator(
        dag, main_func, timedelta(minutes=1), day_shift=1
    )

    # The execution date is the same calendar day, but timezone-aware (UTC).
    ti = TaskInstance(runner, run_day.replace(tzinfo=timezone.utc))
    ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)

    main_func.assert_called_with('2018-12-31')
Пример #3
0
 def __init__(self, ti: TaskInstance, render_templates=True):
     """
     Snapshot the rendered template fields of a task instance.

     :param ti: task instance whose task's ``template_fields`` are recorded
     :param render_templates: if true, call ``ti.render_templates()`` before reading the fields
     """
     self.ti = ti
     self.task = task = ti.task
     self.dag_id = ti.dag_id
     self.task_id = ti.task_id
     self.execution_date = ti.execution_date
     if render_templates:
         ti.render_templates()
     # The pod YAML is only rendered when the AIRFLOW_IS_K8S_EXECUTOR_POD
     # environment variable is set to a non-empty value.
     if os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD", None):
         self.k8s_pod_yaml = ti.render_k8s_pod_yaml()
     self.rendered_fields = {
         name: serialize_template_field(getattr(task, name))
         for name in task.template_fields
     }
Пример #4
0
    def test_localtaskjob_maintain_heart_rate(self):
        """
        LocalTaskJob heartbeats once, exits on the second return-code poll, and does
        not sleep anywhere other than the (patched) base-job sleep.
        """
        dagbag = DagBag(dag_folder=TEST_DAG_FOLDER, include_examples=False)
        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
        task = dag.get_task('test_localtaskjob_double_trigger_task')

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.SUCCESS,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )

        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti_run.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())

        # First poll yields None (still running) so exactly one heartbeat happens;
        # the second poll yields 0 so the loop in _execute() exits.
        pending_codes = [None, 0]

        def next_return_code():
            return pending_codes.pop(0)

        started_at = time.time()
        with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start, \
                patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
            mock_ret_code.side_effect = next_return_code
            job1.run()
            assert mock_start.call_count == 1
            assert mock_ret_code.call_count == 2
        elapsed = time.time() - started_at

        assert self.mock_base_job_sleep.call_count == 1
        assert job1.state == State.SUCCESS

        # The only sleep is the patched one (called once, asserted above), so the
        # whole run must finish faster than one heartbeat interval.
        assert elapsed < job1.heartrate
        session.close()
Пример #5
0
    def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, Iterable[str]]):
        """
        Skip every task directly downstream of this operator except the followed branch(es).

        :param ti: the branching task instance
        :param branch_task_ids: a single task ID, an iterable of task IDs, or ``None``
            (follow nothing)

        Downstream tasks that are also reachable from a followed task are kept too. The
        directly-downstream IDs that are followed are stored to XCom so that
        NotPreviouslySkippedDep knows skipped tasks or newly added tasks should be skipped
        when they are cleared.
        """
        self.log.info("Following branch %s", branch_task_ids)
        if branch_task_ids is None:
            followed = set()
        elif isinstance(branch_task_ids, str):
            followed = {branch_task_ids}
        else:
            followed = set(branch_task_ids)

        dag_run = ti.get_dagrun()
        task = ti.task
        dag = task.dag
        assert dag  # For Mypy.

        # At runtime, the downstream list will only be operators
        downstream_tasks = cast("List[BaseOperator]", task.downstream_list)
        if not downstream_tasks:
            return

        # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
        # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
        # "join" is also immediately downstream of "branch" and should have been skipped. So every
        # task downstream of a followed task is treated as followed as well.
        #
        # branch  ----->  join
        #   \            ^
        #     v        /
        #       task1
        #
        # Iterate over a snapshot because ``followed`` grows inside the loop.
        for followed_id in list(followed):
            followed.update(dag.get_task(followed_id).get_flat_relative_ids(upstream=False))

        skip_tasks = []
        follow_task_ids = []
        for downstream in downstream_tasks:
            if downstream.task_id in followed:
                follow_task_ids.append(downstream.task_id)
            else:
                skip_tasks.append(downstream)

        self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
        with create_session() as session:
            self._set_state_to_skipped(dag_run, skip_tasks, session=session)
            # For some reason, session.commit() needs to happen before xcom_push.
            # Otherwise the session is not committed.
            session.commit()
            ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})
Пример #6
0
def test_should_continue_with_cp(load_dag):
    """The ``continue_if_data_<table>`` branch executes when XCom reports data for the table."""
    table = 'staging.users'
    dag = load_dag('bq_to_wrench').get_dag('bq_to_wrench')
    branch = dag.get_task(f'continue_if_data_{table}')
    assert isinstance(branch, BranchPythonOperator)

    ti = TaskInstance(task=branch, execution_date=datetime.now())
    # Seed the XCom value the branch callable reads.
    XCom.set(
        key=table,
        value={'has_data': True},
        task_id=branch.task_id,
        dag_id=dag.dag_id,
        execution_date=ti.execution_date,
    )

    branch.execute(ti.get_template_context())
Пример #7
0
def test_sets_initial_checkpoint(load_dag, env, bigquery_helper):
    """Running ``get_checkpoint_<table>`` with no stored checkpoint is expected to create one."""
    table = 'staging.users'

    # Start from a clean slate: delete every checkpoint row for the table.
    bigquery_helper.query(
        f"DELETE FROM `{env['project']}.system.checkpoint` WHERE table = '{table}'"
    )

    dag = load_dag('bq_to_wrench').get_dag('bq_to_wrench')
    get_checkpoint = dag.get_task(f'get_checkpoint_{table}')
    assert isinstance(get_checkpoint, GetCheckpointOperator)

    context = TaskInstance(task=get_checkpoint, execution_date=datetime.now()).get_template_context()
    get_checkpoint.execute(context)
Пример #8
0
    def test_try_adopt_task_instances(self):
        """``BaseExecutor.try_adopt_task_instances`` hands the given TIs back unchanged."""
        date = datetime.utcnow()
        start_date = datetime.utcnow() - timedelta(days=2)

        with DAG("test_try_adopt_task_instances"):
            tasks = [
                BaseOperator(task_id=task_id, start_date=start_date)
                for task_id in ("task_1", "task_2", "task_3")
            ]

        tis = [TaskInstance(task=task, execution_date=date) for task in tasks]
        self.assertEqual(BaseExecutor().try_adopt_task_instances(tis), tis)
Пример #9
0
 def _run_task(self, ti: TaskInstance) -> bool:
     """
     Run *ti* in-process and record its final state via ``change_state``.

     Keyword arguments stored for the TI's key in ``self.tasks_params`` are
     removed from that mapping and forwarded to ``_run_raw_task``.

     :return: True when the run (and the SUCCESS state update) raised nothing,
         False otherwise (the TI is then marked FAILED)
     """
     self.log.debug("Executing task: %s", ti)
     key = ti.key
     try:
         extra_params = self.tasks_params.pop(key, {})
         ti._run_raw_task(job_id=ti.job_id, **extra_params)  # pylint: disable=protected-access
         self.change_state(key, State.SUCCESS)
     except Exception as e:  # pylint: disable=broad-except
         self.change_state(key, State.FAILED)
         self.log.exception("Failed to execute task: %s.", str(e))
         return False
     return True
Пример #10
0
 def __init__(self, ti: TaskInstance, render_templates=True):
     """
     Record the rendered template fields of *ti*'s task.

     :param ti: task instance to snapshot
     :param render_templates: render the TI's templates before reading the fields
     """
     task = ti.task
     self.dag_id, self.task_id = ti.dag_id, ti.task_id
     self.task = task
     self.execution_date = ti.execution_date
     self.ti = ti
     if render_templates:
         ti.render_templates()
     # Only the Kubernetes-based executors need the rendered pod spec.
     if IS_K8S_OR_K8SCELERY_EXECUTOR:
         self.k8s_pod_yaml = ti.render_k8s_pod_yaml()
     rendered = {}
     for field in task.template_fields:
         rendered[field] = serialize_template_field(getattr(task, field))
     self.rendered_fields = rendered
Пример #11
0
    def test_mark_success_on_success_callback(self):
        """
        Marking a running task as successful (as the UI does) must trigger its
        on_success_callback.
        """
        data = {'called': False}

        def success_callback(context):
            self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
            data['called'] = True

        dag = DAG(
            dag_id='test_mark_success',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'},
        )
        task = DummyOperator(
            task_id='test_state_succeeded1',
            dag=dag,
            on_success_callback=success_callback,
        )

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        job1.task_runner = StandardTaskRunner(job1)
        process = multiprocessing.Process(target=job1.run)
        process.start()

        # Poll (at most 50 times, 0.1 s apart) until the child process has the TI RUNNING.
        ti.refresh_from_db()
        attempts = 0
        while ti.state != State.RUNNING and attempts < 50:
            time.sleep(0.1)
            ti.refresh_from_db()
            attempts += 1
        self.assertEqual(State.RUNNING, ti.state)

        # Simulate the UI "mark success" action.
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        job1.heartbeat_callback(session=None)
        self.assertTrue(data['called'])
        process.join(timeout=10)
        self.assertFalse(process.is_alive())
class PythonIdempatomicFileOperatorTest_Idempotent(unittest.TestCase):
    """Re-running a completed PythonIdempatomicFileOperator must not re-invoke its callable."""

    def f(self, output_path):
        """Append ``test`` to *output_path*, so a repeated call would be visible."""
        with open(output_path, "a+") as fout:
            fout.write("test")

    def test_PyIdempaOp_idempotent(self):
        self.dag = DAG(
            TEST_DAG_ID,
            schedule_interval="@daily",
            default_args={"start_date": datetime.now()},
        )

        with TemporaryDirectory() as tempdir:
            output_path = f"{tempdir}/test_file.txt"

            def contents():
                with open(output_path, "r") as fin:
                    return fin.read()

            # ensure doesn't already exist
            self.assertFalse(os.path.exists(output_path))

            self.op = PythonIdempatomicFileOperator(
                dag=self.dag,
                task_id="test",
                output_pattern=output_path,
                python_callable=self.f,
            )
            self.ti = TaskInstance(task=self.op, execution_date=datetime.now())

            # First run: the callable executes and produces the file.
            result = self.op.execute(self.ti.get_template_context())
            self.assertEqual(result, output_path)
            self.assertFalse(self.op.previously_completed)
            self.assertTrue(os.path.exists(output_path))
            self.assertEqual(contents(), "test")

            # Second run: the path is still returned...
            result = self.op.execute(self.ti.get_template_context())
            self.assertEqual(result, output_path)
            self.assertTrue(self.op.previously_completed)
            # ...but the callable did not run again (the file would read 'testtest').
            self.assertEqual(contents(), "test")

            # Sanity check: calling the callable directly really does append.
            self.f(output_path)
            self.assertEqual(contents(), "testtest")
class PythonIdempatomicFileOperatorTest_Atomic(unittest.TestCase):
    """A callable that fails mid-write must leave no partially written output behind."""

    def g(self, output_path):
        """Start writing to *output_path*, then fail before finishing."""
        with open(output_path, "w") as fout:
            fout.write("test")
            raise ValueError("You cannot write that!")

    def test_PyIdempaOp_atomic(self):
        self.dag = DAG(
            TEST_DAG_ID,
            schedule_interval="@daily",
            default_args={"start_date": datetime.now()},
        )

        with TemporaryDirectory() as tempdir:
            output_path = f"{tempdir}/test.txt"

            self.assertFalse(
                os.path.exists(output_path))  # ensure doesn't already exist

            self.op = PythonIdempatomicFileOperator(
                dag=self.dag,
                task_id="test",
                output_pattern=output_path,
                python_callable=self.g,
            )
            self.ti = TaskInstance(task=self.op, execution_date=datetime.now())

            # Execute and check while the temporary directory still exists. If this
            # ran after the `with` block, the directory would already be deleted and
            # the existence check below would pass trivially.
            with self.assertRaises(
                    ValueError):  # ensure ValueError is triggered (since task ran)
                self.op.execute(self.ti.get_template_context())

            self.assertFalse(
                os.path.exists(output_path))  # no partially written file
    def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, Iterable[str]]):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        :param ti: the branching task instance
        :param branch_task_ids: a task ID, an iterable of task IDs, or ``None`` to
            follow no branch at all
        """
        self.log.info("Following branch %s", branch_task_ids)
        if isinstance(branch_task_ids, str):
            followed = {branch_task_ids}
        elif branch_task_ids is None:
            followed = set()
        else:
            # Materialize once: a one-shot iterator would be exhausted by the loop
            # below, and every membership test after it would then silently fail.
            followed = set(branch_task_ids)

        dag_run = ti.get_dagrun()
        task = ti.task
        dag = task.dag

        downstream_tasks = task.downstream_list

        if downstream_tasks:
            # Also check downstream tasks of the branch task. In case the task to skip
            # is also a downstream task of the branch task, we exclude it from skipping.
            branch_downstream_task_ids = set()  # type: Set[str]
            for b in followed:
                branch_downstream_task_ids.update(
                    dag.get_task(b).get_flat_relative_ids(upstream=False)
                )

            keep_task_ids = followed | branch_downstream_task_ids
            skip_tasks = [t for t in downstream_tasks if t.task_id not in keep_task_ids]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            self.skip(dag_run, ti.execution_date, skip_tasks)
Пример #15
0
    def _get_ready_tis(
        self,
        scheduleable_tasks: List[TI],
        finished_tasks: List[TI],
        session: Session,
    ) -> Tuple[List[TI], bool]:
        """
        Return the candidate task instances whose dependencies are met.

        :param scheduleable_tasks: candidate task instances
        :param finished_tasks: finished task instances, handed to the dependency context
        :param session: database session
        :return: ``(ready_tis, changed_tis)``. ``changed_tis`` is true when any candidate
            whose dependencies were *not* met, re-read through *session*, now has a state
            different from the one it had before its dependencies were evaluated.
        """
        old_states = {}
        ready_tis: List[TI] = []
        changed_tis = False

        if not scheduleable_tasks:
            return ready_tis, changed_tis

        # The dependency context is the same for every candidate, so build it once
        # instead of once per task instance.
        dep_context = DepContext(flag_upstream_failed=True, finished_tasks=finished_tasks)

        # Check dependencies
        for st in scheduleable_tasks:
            old_state = st.state
            if st.are_dependencies_met(dep_context=dep_context, session=session):
                ready_tis.append(st)
            else:
                old_states[st.key] = old_state

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states.keys())
        if tis_filter is not None:
            fresh_tis = session.query(TI).filter(tis_filter).all()
            changed_tis = any(ti.state != old_states[ti.key]
                              for ti in fresh_tis)

        return ready_tis, changed_tis
Пример #16
0
 def _start_task_instance(self, key: TaskInstanceKey):
     """
     Ignore all dependencies, force start a task instance

     The TI for *key* is loaded, its run command is generated with every
     dependency / state check disabled, the TI is set to QUEUED and the command
     is submitted through ``execute_async``. If no TI exists for *key*, an error
     is logged and nothing is started.
     """
     ti = self.get_task_instance(key)
     if ti is None:
         self.log.error("TaskInstance not found in DB, %s.", str(key))
         return

     force_flags = dict(
         ignore_all_deps=True,
         ignore_depends_on_past=True,
         ignore_task_deps=True,
         ignore_ti_state=True,
     )
     command = TaskInstance.generate_command(
         ti.dag_id,
         ti.task_id,
         ti.execution_date,
         local=True,
         mark_success=False,
         **force_flags,
         pool=ti.pool,
         file_path=ti.dag_model.fileloc,
         pickle_id=ti.dag_model.pickle_id,
         server_uri=self._server_uri,
     )
     ti.set_state(State.QUEUED)
     self.execute_async(
         key=key,
         command=command,
         queue=ti.queue,
         executor_config=ti.executor_config,
     )
    def test_error_sending_task(self):
        """A task whose Celery publish raises is removed from the queue and recorded as FAILED."""
        def fake_execute_command():
            pass

        with _prepare_app(execute=fake_execute_command):
            # fake_execute_command takes no arguments while execute_command takes 1,
            # which will cause TypeError when calling task.apply_async()
            executor = celery_executor.CeleryExecutor()
            task = BashOperator(
                task_id="test",
                bash_command="true",
                dag=DAG(dag_id='id'),
                start_date=datetime.now(),
            )
            when = datetime.now()
            key = ('fail', 'fake_simple_ti', when, 0)
            simple_ti = SimpleTaskInstance(
                ti=TaskInstance(task=task, execution_date=datetime.now()))
            executor.queued_tasks[key] = ('command', 1, None, simple_ti)
            executor.task_publish_retries[key] = 1
            executor.heartbeat()

        assert not executor.queued_tasks, "Task should no longer be queued"
        assert executor.event_buffer[key][0] == State.FAILED
Пример #18
0
    def _send_stalled_tis_back_to_scheduler(
            self,
            keys: List[TaskInstanceKey],
            session: Session = NEW_SESSION) -> None:
        """
        Reset stalled QUEUED task instances to SCHEDULED so the scheduler retries them.

        Only rows matching *keys* that are still QUEUED and were queued by this job
        (``self.job_id``) are updated; their queued timestamp, queuing job id and
        external executor id are cleared. If the update or commit raises, the session
        is rolled back and the method returns without touching executor state.
        Otherwise each key's pending timeout is cleared, it is removed from the
        running set and the task map, and its Celery task (if any) is revoked.
        """
        try:
            stalled = session.query(TaskInstance).filter(
                TaskInstance.filter_for_tis(keys),
                TaskInstance.state == State.QUEUED,
                TaskInstance.queued_by_job_id == self.job_id,
            )
            stalled.update(
                {
                    TaskInstance.state: State.SCHEDULED,
                    TaskInstance.queued_dttm: None,
                    TaskInstance.queued_by_job_id: None,
                    TaskInstance.external_executor_id: None,
                },
                synchronize_session=False,
            )
            session.commit()
        except Exception:
            self.log.exception("Error sending tasks back to scheduler")
            session.rollback()
            return

        for key in keys:
            self._set_celery_pending_task_timeout(key, None)
            self.running.discard(key)
            async_result = self.tasks.pop(key, None)
            if not async_result:
                continue
            try:
                app.control.revoke(async_result.task_id)
            except Exception as ex:
                self.log.error("Error revoking task instance %s from celery: %s", key, ex)
Пример #19
0
    def verify_integrity(self, session=None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: database session used for all reads and writes; it is
            committed at the end (rolled back on IntegrityError)
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                # Compare by value, not identity: ``self.state`` is usually loaded from
                # the database, so it is generally a different str object from the
                # State constant and ``is not`` would almost always be true.
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '{}' for dag '{}'. "
                                     "Marking it as removed.".format(ti, dag))
                    Stats.incr(
                        "task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                    ti.state = State.REMOVED

            is_task_in_dag = task is not None
            should_restore_task = is_task_in_dag and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info("Restoring task '{}' which was previously "
                              "removed from DAG '{}'".format(ti, dag))
                Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # check for missing tasks
        for task in six.itervalues(dag.task_dict):
            # Tasks starting after this run's execution date get no TI, except in backfills.
            if task.start_date > self.execution_date and not self.is_backfill:
                continue

            if task.task_id not in task_ids:
                Stats.incr(
                    "task_instance_created-{}".format(task.__class__.__name__),
                    1, 1)
                ti = TaskInstance(task, self.execution_date)
                task_instance_mutation_hook(ti)
                session.add(ti)

        try:
            session.commit()
        except IntegrityError as err:
            self.log.info(str(err))
            self.log.info(
                'Hit IntegrityError while creating the TIs for %s - %s',
                dag.dag_id, self.execution_date
            )
            self.log.info('Doing session rollback.')
            session.rollback()
Пример #20
0
    def queue_task_instance(self,
                            task_instance: TaskInstance,
                            mark_success: bool = False,
                            pickle_id: Optional[str] = None,
                            ignore_all_deps: bool = False,
                            ignore_depends_on_past: bool = False,
                            ignore_task_deps: bool = False,
                            ignore_ti_state: bool = False,
                            pool: Optional[str] = None,
                            cfg_path: Optional[str] = None) -> None:
        """
        Queues task instance.

        Builds the task instance's local run command from the given flags and
        queues it with the task's ``priority_weight_total`` and ``queue``.
        *pool* falls back to ``task_instance.pool`` when falsy.
        """
        # TODO (edgarRd): AIRFLOW-1985:
        # cfg_path is needed to propagate the config values if using impersonation
        # (run_as_user), given that there are different code paths running tasks.
        # For a long term solution we need to address AIRFLOW-1986
        command_list_to_run = task_instance.command_as_list(
            local=True,
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            ignore_task_deps=ignore_task_deps,
            ignore_ti_state=ignore_ti_state,
            pool=pool or task_instance.pool,
            pickle_id=pickle_id,
            cfg_path=cfg_path,
        )
        task = task_instance.task
        self.queue_command(
            SimpleTaskInstance(task_instance),
            command_list_to_run,
            priority=task.priority_weight_total,
            queue=task.queue,
        )
Пример #21
0
 def create_ti_mapping(task: "Operator",
                       indexes: Tuple[int, ...]) -> Generator:
     """
     Lazily yield one ``TI.insert_mapping`` row per map index of *task*.

     ``created_counts[task.task_type]`` is incremented by one (not once per
     index), and only when iteration of the generator begins.
     """
     created_counts[task.task_type] += 1
     yield from (
         TI.insert_mapping(self.run_id, task, map_index=index)
         for index in indexes
     )
Пример #22
0
    def _get_ready_tis(
        self,
        schedulable_tis: List[TI],
        finished_tis: List[TI],
        session: Session,
    ) -> Tuple[List[TI], bool]:
        """
        Return the schedulable TIs whose dependencies are met, expanding mapped tasks if needed.

        :param schedulable_tis: candidate task instances
        :param finished_tis: finished task instances; handed to the dependency context and
            used to find an upstream of a not-yet-expanded mapped task
        :param session: database session
        :return: ``(ready_tis, changed_tis)``. ``changed_tis`` is true when any TI whose
            dependencies were not met, re-read through *session*, now has a different state
            than it had before its dependencies were evaluated. A TI that expansion turns
            SKIPPED is not included in ``ready_tis``.
        """
        old_states = {}
        ready_tis: List[TI] = []
        changed_tis = False

        if not schedulable_tis:
            return ready_tis, changed_tis

        # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
        # `schedulable_tis` in place and have the `for` loop pick them up.)
        expanded_tis: List[TI] = []
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
            finished_tis=finished_tis,
        )

        # Check dependencies. ``itertools.chain`` reads ``expanded_tis`` lazily, so TIs
        # appended to it during the loop are checked as well.
        for schedulable in itertools.chain(schedulable_tis, expanded_tis):
            old_state = schedulable.state
            if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
                old_states[schedulable.key] = old_state
                continue

            skipped_by_expansion = False
            # Expansion of last resort! This is ideally handled in the mini-scheduler in LocalTaskJob, but if
            # for any reason it wasn't, we need to expand it now
            if schedulable.map_index < 0 and schedulable.task.is_mapped:
                # HACK. This needs a better way, one that copes with multiple upstreams!
                for ti in finished_tis:
                    if schedulable.task_id not in ti.task.downstream_task_ids:
                        continue
                    assert isinstance(schedulable.task, MappedOperator)
                    new_tis = schedulable.task.expand_mapped_task(self.run_id, session=session)
                    if schedulable.state == TaskInstanceState.SKIPPED:
                        # Task is now skipped (likely because upstream returned 0 tasks).
                        skipped_by_expansion = True
                    else:
                        assert new_tis[0] is schedulable
                        expanded_tis.extend(new_tis[1:])
                    # Expand at most once: scanning further finished TIs would expand
                    # the same mapped task again.
                    break

            # A TI the expansion just skipped must not be handed back for scheduling.
            if not skipped_by_expansion:
                ready_tis.append(schedulable)

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states.keys())
        if tis_filter is not None:
            fresh_tis = session.query(TI).filter(tis_filter).all()
            changed_tis = any(ti.state != old_states[ti.key]
                              for ti in fresh_tis)

        return ready_tis, changed_tis
Пример #23
0
    def get_link(self, operator, dttm):
        """
        Get link to qubole command result page.

        :param operator: operator whose ``kwargs['qubole_conn_id']`` names the connection
        :param dttm: execution datetime of the task instance to look up
        :return: url link, or ``''`` when no ``qbol_cmd_id`` was pushed to XCom
        """
        ti = TaskInstance(task=operator, execution_date=dttm)
        conn = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
        # Default host, overridden by the connection's host with a trailing "api"
        # replaced by the analyze-page path.
        host = 'https://api.qubole.com/v2/analyze?command_id='
        if conn and conn.host:
            host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
        qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
        if not qds_command_id:
            return ''
        return host + str(qds_command_id)
Пример #24
0
    def skip_all_except(
        self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]
    ):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.

        :param ti: the branching task instance
        :param branch_task_ids: a task ID, an iterable of task IDs, or ``None``
            (follow nothing); always stored to XCom as a list
        """
        self.log.info("Following branch %s", branch_task_ids)
        if isinstance(branch_task_ids, str):
            branch_task_ids = [branch_task_ids]
        elif branch_task_ids is None:
            branch_task_ids = []
        else:
            # Materialize once as a list: a generator would be exhausted by the loop
            # below (breaking the membership test), and a set is not JSON-serializable
            # when pushed to XCom.
            branch_task_ids = list(branch_task_ids)

        dag_run = ti.get_dagrun()
        task = ti.task
        dag = task.dag

        downstream_tasks = task.downstream_list

        if downstream_tasks:
            # Also check downstream tasks of the branch task. In case the task to skip
            # is also a downstream task of the branch task, we exclude it from skipping.
            branch_downstream_task_ids = set()  # type: Set[str]
            for branch_task_id in branch_task_ids:
                branch_downstream_task_ids.update(
                    dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
                )

            keep_task_ids = set(branch_task_ids) | branch_downstream_task_ids
            skip_tasks = [t for t in downstream_tasks if t.task_id not in keep_task_ids]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            with create_session() as session:
                self._set_state_to_skipped(
                    dag_run, ti.execution_date, skip_tasks, session=session
                )
                ti.xcom_push(
                    key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids}
                )
Пример #25
0
    def _get_ready_tis(
        self,
        schedulable_tis: List[TI],
        finished_tis: List[TI],
        session: Session,
    ) -> Tuple[List[TI], bool, bool]:
        """
        Return the schedulable TIs whose dependencies are met, expanding unexpanded mapped tasks.

        :param schedulable_tis: candidate task instances
        :param finished_tis: finished task instances, handed to the dependency context
        :param session: database session (also used for mapped-task expansion)
        :return: ``(ready_tis, changed_tis, expansion_happened)``:
            ``ready_tis`` are the TIs whose dependencies are met and whose state (after any
            expansion) is in ``SCHEDULEABLE_STATES``; ``changed_tis`` is true when any TI whose
            dependencies were not met, re-read through *session*, now has a different state
            than before its dependencies were evaluated; ``expansion_happened`` is true when
            ``expand_mapped_task`` was called for at least one TI.
        """
        old_states = {}
        ready_tis: List[TI] = []
        changed_tis = False

        if not schedulable_tis:
            return ready_tis, changed_tis, False

        # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
        # `schedulable_tis` in place and have the `for` loop pick them up
        additional_tis: List[TI] = []
        # One dependency context is shared by every TI checked below.
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_unmapped_tasks=
            True,  # Ignore this Dep, as we will expand it if we can.
            finished_tis=finished_tis,
        )

        # Check dependencies.
        expansion_happened = False
        # itertools.chain reads `additional_tis` lazily, so TIs appended to it
        # inside this loop are visited (and dependency-checked) as well.
        for schedulable in itertools.chain(schedulable_tis, additional_tis):
            old_state = schedulable.state
            if not schedulable.are_dependencies_met(session=session,
                                                    dep_context=dep_context):
                # Remember the pre-check state so a state change made during the
                # check (e.g. via flag_upstream_failed) can be detected below.
                old_states[schedulable.key] = old_state
                continue
            # If schedulable is from a mapped task, but not yet expanded, do it
            # now. This is called in two places: First and ideally in the mini
            # scheduler at the end of LocalTaskJob, and then as an "expansion of
            # last resort" in the scheduler to ensure that the mapped task is
            # correctly expanded before executed.
            if schedulable.map_index < 0 and isinstance(
                    schedulable.task, MappedOperator):
                expanded_tis, _ = schedulable.task.expand_mapped_task(
                    self.run_id, session=session)
                if expanded_tis:
                    # The first expanded TI is expected to reuse this TI; the rest
                    # are new and are queued for checking by the loop above.
                    assert expanded_tis[0] is schedulable
                    additional_tis.extend(expanded_tis[1:])
                expansion_happened = True
            # Expansion may have changed the state (e.g. to skipped), so re-check
            # before handing the TI back for scheduling.
            if schedulable.state in SCHEDULEABLE_STATES:
                ready_tis.append(schedulable)

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states)
        if tis_filter is not None:
            fresh_tis = session.query(TI).filter(tis_filter).all()
            changed_tis = any(ti.state != old_states[ti.key]
                              for ti in fresh_tis)

        return ready_tis, changed_tis, expansion_happened
Пример #26
0
    def test_open_slots(self):
        """A 5-slot pool holding one RUNNING and one QUEUED TI reports 3 open slots."""
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)

        # Both tasks are bound to the pool under test.
        operators = [
            DummyOperator(task_id=task_id, dag=dag, pool='test_pool')
            for task_id in ('dummy1', 'dummy2')
        ]
        running_ti, queued_ti = [
            TI(task=operator, execution_date=DEFAULT_DATE) for operator in operators
        ]
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        # settings.Session itself is used here; it is not called to create a new session.
        db_session = settings.Session
        for record in (pool, running_ti, queued_ti):
            db_session.add(record)
        db_session.commit()
        db_session.close()

        self.assertEqual(3, pool.open_slots())
Пример #27
0
    def test_open_slots(self):
        """Check open_slots() for a 5-slot pool with one RUNNING and one QUEUED TI."""
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(
            dag_id='test_open_slots',
            start_date=DEFAULT_DATE,
        )
        first_op = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        second_op = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        running_ti = TI(task=first_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=second_op, execution_date=DEFAULT_DATE)
        running_ti.state, queued_ti.state = State.RUNNING, State.QUEUED

        session = settings.Session
        session.add_all([pool, running_ti, queued_ti])
        session.commit()
        session.close()

        # Two of the five slots are taken (running + queued).
        self.assertEqual(3, pool.open_slots())

    def test_get_redirect_url(self):
        """Check the 'Go to QDS' extra link with and without a pushed command id."""
        dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
        with dag:
            task = QuboleOperator(
                task_id=TASK_ID,
                qubole_conn_id=TEST_CONN,
                command_type='shellcmd',
                parameters="param1 param2",
                dag=dag,
            )

        # Store a command id for DEFAULT_DATE only.
        pushed_ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        pushed_ti.xcom_push('qbol_cmd_id', 12345)

        # Positive case: the date with a pushed id links to that command.
        self.assertEqual(
            task.get_extra_links(DEFAULT_DATE, 'Go to QDS'),
            'http://localhost/v2/analyze?command_id=12345',
        )

        # Negative case: a date with nothing pushed yields an empty link.
        self.assertEqual(task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS'), '')
Пример #29
0
    def test_localtaskjob_double_trigger(self):
        """
        A LocalTaskJob for a TI already marked RUNNING must not start it again.

        The TI is first stored as RUNNING with hostname ``get_hostname()`` and
        pid 1. The job is then run with ``StandardTaskRunner.start`` patched.
        The test checks that ``start`` is never called and that the stored pid
        and state are left unchanged.
        """
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
        task = dag.get_task('test_localtaskjob_double_trigger_task')

        session = settings.Session()

        dag.clear()
        dr = dag.create_dagrun(
            run_id="test",
            state=State.SUCCESS,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        # Record an already-running attempt of this TI: RUNNING, this
        # get_hostname() value, and a fixed pid of 1.
        ti = dr.get_task_instance(task_id=task.task_id, session=session)
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.merge(ti)
        session.commit()

        # A separate TaskInstance object for the same TI, loaded from the DB.
        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti_run.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti_run,
                            executor=SequentialExecutor())
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        # Patch start() so the test can observe whether the job tried to launch the task.
        with patch.object(StandardTaskRunner, 'start',
                          return_value=None) as mock_method:
            job1.run()
            mock_method.assert_not_called()

        # Re-read the TI: the pid and RUNNING state written above must be untouched.
        ti = dr.get_task_instance(task_id=task.task_id, session=session)
        self.assertEqual(ti.pid, 1)
        self.assertEqual(ti.state, State.RUNNING)

        # NOTE(review): not in a finally block, so the session stays open if an
        # assertion above fails.
        session.close()
    def skip(self, dag_run, execution_date, tasks, session=None):
        """
        Mark the task instances of the given tasks as SKIPPED.

        When ``dag_run`` is truthy, the matching task instance rows of that run
        are updated with one bulk query. Otherwise a ``TaskInstance`` is built
        for each task at ``execution_date`` and merged into the session.

        :param dag_run: the DagRun whose task instances are skipped; may be
            falsy, in which case ``execution_date`` is used instead
        :param execution_date: execution date used when there is no dag run
        :param tasks: tasks to skip (the task objects, not task_ids)
        :param session: db session to use
        :raises ValueError: if there is no dag run and ``execution_date`` is None
        """
        if not tasks:
            return

        skipped_ids = [task.task_id for task in tasks]
        timestamp = timezone.utcnow()

        if not dag_run:
            if execution_date is None:
                raise ValueError("Execution date is None and no dag run")

            self.log.warning("No DAG RUN present this should not happen")
            # Defensive path for dag runs that are not complete.
            for task in tasks:
                skipped_ti = TaskInstance(task, execution_date=execution_date)
                skipped_ti.state = State.SKIPPED
                skipped_ti.start_date = timestamp
                skipped_ti.end_date = timestamp
                session.merge(skipped_ti)

            session.commit()
            return

        matching_tis = session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.execution_date == dag_run.execution_date,
            TaskInstance.task_id.in_(skipped_ids),
        )
        matching_tis.update(
            {
                TaskInstance.state: State.SKIPPED,
                TaskInstance.start_date: timestamp,
                TaskInstance.end_date: timestamp,
            },
            synchronize_session=False,
        )
        session.commit()
Пример #31
0
    def test_open_slots(self):
        """Check all slot counters of a 5-slot pool with one RUNNING and one QUEUED TI."""
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)
        op1, op2 = (
            DummyOperator(task_id=task_id, dag=dag, pool='test_pool')
            for task_id in ('dummy1', 'dummy2')
        )
        ti1, ti2 = (TI(task=op, execution_date=DEFAULT_DATE) for op in (op1, op2))
        ti1.state = State.RUNNING
        ti2.state = State.QUEUED

        session = settings.Session
        for row in (pool, ti1, ti2):
            session.add(row)
        session.commit()
        session.close()

        # Each pair is (expected value, pool counter); counters are called in order.
        for expected, slot_counter in (
            (3, pool.open_slots),
            (1, pool.running_slots),
            (1, pool.queued_slots),
            (2, pool.occupied_slots),
        ):
            self.assertEqual(expected, slot_counter())

        expected_stats = {
            "default_pool": {
                "open": 128,
                "queued": 0,
                "total": 128,
                "running": 0,
            },
            "test_pool": {
                "open": 3,
                "queued": 1,
                "running": 1,
                "total": 5,
            },
        }
        self.assertEqual(expected_stats, pool.slots_stats())
Пример #32
0
    def skip(self, dag_run, execution_date, tasks, session=None):
        """
        Sets tasks instances to skipped from the same dag run.

        With a truthy ``dag_run``, the matching task instance rows of that run
        are updated in one bulk query. Without one, a ``TaskInstance`` is built
        for each task at ``execution_date`` and merged into the session.

        :param dag_run: the DagRun for which to set the tasks to skipped; may be
            falsy, in which case ``execution_date`` is used instead
        :param execution_date: execution_date, required when ``dag_run`` is falsy
        :param tasks: tasks to skip (not task_ids); any iterable is accepted,
            including a one-shot iterator
        :param session: db session to use
        :raises ValueError: if ``dag_run`` is falsy and ``execution_date`` is None
        """
        # Materialize once: ``tasks`` is iterated twice below (for the ids and
        # in the fallback loop). A generator would be exhausted by the first pass.
        tasks = list(tasks or ())
        if not tasks:
            return

        # Use an explicit exception rather than ``assert``, which is stripped
        # when Python runs with -O. Checked before any DB or clock access.
        if not dag_run and execution_date is None:
            raise ValueError("Execution date is None and no dag run")

        task_ids = [task.task_id for task in tasks]
        now = timezone.utcnow()

        if dag_run:
            session.query(TaskInstance).filter(
                TaskInstance.dag_id == dag_run.dag_id,
                TaskInstance.execution_date == dag_run.execution_date,
                TaskInstance.task_id.in_(task_ids),
            ).update(
                {
                    TaskInstance.state: State.SKIPPED,
                    TaskInstance.start_date: now,
                    TaskInstance.end_date: now,
                },
                synchronize_session=False,
            )
            session.commit()
        else:
            self.log.warning("No DAG RUN present this should not happen")
            # this is defensive against dag runs that are not complete
            for task in tasks:
                ti = TaskInstance(task, execution_date=execution_date)
                ti.state = State.SKIPPED
                ti.start_date = now
                ti.end_date = now
                session.merge(ti)

            session.commit()
Пример #33
0
 def get_link(self, operator, dttm):
     """
     Build a BigQuery console URL from the ``job_id`` the operator stored in XCom.

     :param operator: operator whose XCom value under key ``job_id`` is read
     :param dttm: execution date of the task instance to read from
     :return: the console URL, or an empty string when the pulled value is falsy
     """
     job_id = TaskInstance(task=operator, execution_date=dttm).xcom_pull(
         task_ids=operator.task_id, key='job_id')
     if not job_id:
         return ''
     return 'https://console.cloud.google.com/bigquery?j={job_id}'.format(job_id=job_id)