Example #1
    def test_reduce_in_chunks(self):
        assert helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5],
                                        []) == [[1, 2, 3, 4, 5]]

        assert helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5],
                                        [], 2) == [[1, 2], [3, 4], [5]]

        assert helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1],
                                        [1, 2, 3, 4], 0, 2) == 14
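
For orientation: reduce_in_chunks splits the iterable into chunks of chunk_size items (0 means one chunk holding everything) and folds the reducer over those chunks, which is why the tests above yield one batch, batches of at most two items, and a running sum. A minimal sketch of that behaviour, matching the semantics exercised here but not copied from the Airflow source:

from functools import reduce

def chunks(items, chunk_size):
    # yield successive slices of at most chunk_size items
    for i in range(0, len(items), chunk_size):
        yield items[i:i + chunk_size]

def reduce_in_chunks(fn, iterable, initializer, chunk_size=0):
    # chunk_size == 0 means "process the whole iterable as a single chunk"
    if not iterable:
        return initializer
    if chunk_size == 0:
        chunk_size = len(iterable)
    return reduce(fn, chunks(iterable, chunk_size), initializer)

assert reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], [], 2) == [[1, 2], [3, 4], [5]]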
Example #2
    def test_reduce_in_chunks(self):
        self.assertEqual(
            helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], []), [[1, 2, 3, 4, 5]]
        )

        self.assertEqual(
            helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], [], 2), [[1, 2], [3, 4], [5]]
        )

        self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1], [1, 2, 3, 4], 0, 2), 14)
Example #3
    def _update_ti_hostname(self, sensor_works, session=None):
        """
        Update task instance hostname for new sensor works.

        :param sensor_works: The smart sensor internal work objects for the sensor tasks to update.
        :param session: The sqlalchemy session.
        """
        TI = TaskInstance
        ti_keys = [(x.dag_id, x.task_id, x.execution_date) for x in sensor_works]

        def update_ti_hostname_with_count(count, ti_keys):
            # Using or_ instead of in_ here to prevent a full table scan.
            tis = session.query(TI) \
                .filter(or_(tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key
                            for ti_key in ti_keys)) \
                .all()

            for ti in tis:
                ti.hostname = self.hostname
            session.commit()

            return count + len(ti_keys)

        count = helpers.reduce_in_chunks(update_ti_hostname_with_count, ti_keys, 0, self.max_tis_per_query)
        if count:
            self.log.info("Updated hostname on %s tis.", count)
Example #4
    def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None):
        """
        This function checks if there are any tasks in the dagrun (or all) that
        have a scheduled or queued state but are not known by the executor. If
        it finds any, it resets their state to None so they will get picked up
        again. The batch option is for performance reasons, as the queries are
        made in sequence.

        :param filter_by_dag_run: the dag_run we want to process, None if all
        :return: the number of TIs reset
        :rtype: int
        """
        queued_tis = self.executor.queued_tasks
        # also consider running as the state might not have changed in the db yet
        running_tis = self.executor.running

        # Can't use an update here since it doesn't support joins.
        resettable_states = [TaskInstanceState.SCHEDULED, TaskInstanceState.QUEUED]
        if filter_by_dag_run is None:
            resettable_tis = (
                session.query(TaskInstance)
                .join(TaskInstance.dag_run)
                .filter(
                    DagRun.state == DagRunState.RUNNING,
                    DagRun.run_type != DagRunType.BACKFILL_JOB,
                    TaskInstance.state.in_(resettable_states),
                )
            ).all()
        else:
            resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states, session=session)

        tis_to_reset = [ti for ti in resettable_tis if ti.key not in queued_tis and ti.key not in running_tis]
        if not tis_to_reset:
            return 0

        def query(result, items):
            if not items:
                return result

            filter_for_tis = TaskInstance.filter_for_tis(items)
            reset_tis = (
                session.query(TaskInstance)
                .filter(filter_for_tis, TaskInstance.state.in_(resettable_states))
                .with_for_update()
                .all()
            )

            for ti in reset_tis:
                ti.state = State.NONE
                session.merge(ti)

            return result + reset_tis

        reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [], self.max_tis_per_query)

        task_instance_str = '\n\t'.join(repr(x) for x in reset_tis)
        session.flush()

        self.log.info("Reset the following %s TaskInstances:\n\t%s", len(reset_tis), task_instance_str)
        return len(reset_tis)
Example #5
    def test_reduce_in_chunks(self):
        self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + [y],
                                                  [1, 2, 3, 4, 5],
                                                  []),
                         [[1, 2, 3, 4, 5]])

        self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + [y],
                                                  [1, 2, 3, 4, 5],
                                                  [],
                                                  2),
                         [[1, 2], [3, 4], [5]])

        self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1],
                                                  [1, 2, 3, 4],
                                                  0,
                                                  2),
                         14)
Example #6
    def _update_ti_hostname(self, sensor_works, session=None):
        """
        Update task instance hostname for new sensor works.

        :param sensor_works: The smart sensor internal work objects for the sensor tasks to update.
        :param session: The sqlalchemy session.
        """
        DR = DagRun
        TI = TaskInstance

        def update_ti_hostname_with_count(count, sensor_works):
            # Using or_ instead of in_ here to prevent a full table scan.
            if session.bind.dialect.name == 'mssql':
                ti_filter = or_(
                    and_(
                        TI.dag_id == ti_key.dag_id,
                        TI.task_id == ti_key.task_id,
                        DR.execution_date == ti_key.execution_date,
                    ) for ti_key in sensor_works)
            else:
                ti_keys = [(x.dag_id, x.task_id, x.execution_date)
                           for x in sensor_works]
                ti_filter = or_(
                    tuple_(TI.dag_id, TI.task_id, DR.execution_date) == ti_key
                    for ti_key in ti_keys)

            for ti in session.query(TI).join(TI.dag_run).filter(ti_filter):
                ti.hostname = self.hostname
            session.commit()

            return count + len(sensor_works)

        count = helpers.reduce_in_chunks(update_ti_hostname_with_count,
                                         sensor_works, 0,
                                         self.max_tis_per_query)
        if count:
            self.log.info("Updated hostname on %s tis.", count)
Example #7
    def reset_state_for_orphaned_tasks(self,
                                       filter_by_dag_run=None,
                                       session=None):
        """
        This function checks if there are any tasks in the dagrun (or all)
        that have a scheduled or queued state but are not known by the
        executor. If it finds any, it resets their state to None
        so they will get picked up again.
        The batch option is for performance reasons, as the queries are made
        in sequence.

        :param filter_by_dag_run: the dag_run we want to process, None if all
        :type filter_by_dag_run: airflow.models.DagRun
        :return: the TIs reset (in expired SQLAlchemy state)
        :rtype: list[airflow.models.TaskInstance]
        """
        queued_tis = self.executor.queued_tasks
        # also consider running as the state might not have changed in the db yet
        running_tis = self.executor.running

        resettable_states = [State.SCHEDULED, State.QUEUED]
        TI = models.TaskInstance
        DR = models.DagRun
        if filter_by_dag_run is None:
            resettable_tis = (
                session.query(TI).join(
                    DR,
                    and_(TI.dag_id == DR.dag_id,
                         TI.execution_date == DR.execution_date)).filter(
                             # pylint: disable=comparison-with-callable
                             DR.state == State.RUNNING,
                             DR.run_id.notlike(
                                 f"{DagRunType.BACKFILL_JOB.value}__%"),
                             TI.state.in_(resettable_states))).all()
        else:
            resettable_tis = filter_by_dag_run.get_task_instances(
                state=resettable_states, session=session)
        tis_to_reset = []
        # Can't use an update here since it doesn't support joins
        for ti in resettable_tis:
            if ti.key not in queued_tis and ti.key not in running_tis:
                tis_to_reset.append(ti)

        if len(tis_to_reset) == 0:
            return []

        def query(result, items):
            if not items:
                return result

            filter_for_tis = TI.filter_for_tis(items)
            reset_tis = session.query(TI).filter(
                filter_for_tis,
                TI.state.in_(resettable_states)).with_for_update().all()

            for ti in reset_tis:
                ti.state = State.NONE
                session.merge(ti)

            return result + reset_tis

        reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [],
                                             self.max_tis_per_query)

        task_instance_str = '\n\t'.join([repr(x) for x in reset_tis])
        session.commit()

        self.log.info("Reset the following %s TaskInstances:\n\t%s",
                      len(reset_tis), task_instance_str)
        return reset_tis
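
Filtering out TIs the executor already knows about (the loop over resettable_tis above, and the equivalent list comprehension in Example #4) is plain membership testing on TI keys. A standalone sketch with made-up key tuples:

queued_keys = {('dag_a', 'task_1', 1)}
running_keys = {('dag_a', 'task_2', 1)}
resettable_keys = [('dag_a', 'task_1', 1), ('dag_a', 'task_2', 1), ('dag_a', 'task_3', 1)]

# keep only the keys the executor has neither queued nor running
to_reset = [k for k in resettable_keys
            if k not in queued_keys and k not in running_keys]
assert to_reset == [('dag_a', 'task_3', 1)]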
Example #8
def reset_state_for_orphaned_tasks(single_dag_run_job: BaseJob,
                                   filter_by_dag_run=None,
                                   session=None):
    """
    This function was removed in Airflow 2.x; it is worth checking the Airflow
    GitHub history to see why. For now it replicates the v1 behaviour and is
    probably no longer needed.

    This function checks if there are any tasks in the dagrun (or all)
    that have a scheduled or queued state but are not known by the
    executor. If it finds any, it resets their state to None
    so they will get picked up again.
    The batch option is for performance reasons, as the queries are made
    in sequence.

    :param filter_by_dag_run: the dag_run we want to process, None if all
    :type filter_by_dag_run: airflow.models.DagRun
    :return: the TIs reset (in expired SQLAlchemy state)
    :rtype: list[airflow.models.TaskInstance]
    """
    from airflow.jobs.backfill_job import BackfillJob

    queued_tis = single_dag_run_job.executor.queued_tasks
    # also consider running as the state might not have changed in the db yet
    running_tis = single_dag_run_job.executor.running

    resettable_states = [State.SCHEDULED, State.QUEUED]
    TI = models.TaskInstance
    DR = models.DagRun
    if filter_by_dag_run is None:
        resettable_tis = (session.query(TI).join(
            DR,
            and_(TI.dag_id == DR.dag_id,
                 TI.execution_date == DR.execution_date)).filter(
                     DR.state == State.RUNNING,
                     DR.run_id.notlike(BackfillJob.ID_PREFIX + "%"),
                     TI.state.in_(resettable_states),
                 )).all()
    else:
        resettable_tis = filter_by_dag_run.get_task_instances(
            state=resettable_states, session=session)
    tis_to_reset = []
    # Can't use an update here since it doesn't support joins
    for ti in resettable_tis:
        if ti.key not in queued_tis and ti.key not in running_tis:
            tis_to_reset.append(ti)

    if len(tis_to_reset) == 0:
        return []

    def query(result, items):
        filter_for_tis = [
            and_(
                TI.dag_id == ti.dag_id,
                TI.task_id == ti.task_id,
                TI.execution_date == ti.execution_date,
            ) for ti in items
        ]
        reset_tis = (session.query(TI).filter(
            or_(*filter_for_tis),
            TI.state.in_(resettable_states)).with_for_update().all())
        for ti in reset_tis:
            ti.state = State.NONE
            session.merge(ti)
        return result + reset_tis

    reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [],
                                         single_dag_run_job.max_tis_per_query)

    task_instance_str = "\n\t".join([repr(x) for x in reset_tis])
    session.commit()

    single_dag_run_job.log.info("Reset the following %s TaskInstances:\n\t%s",
                                len(reset_tis), task_instance_str)
    return reset_tis
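
Unlike Examples #4 and #7, this version is a standalone function that takes the job as its first argument. A hedged call-site sketch; job and session are assumed to already exist and the names are ours:

# `job` is a BaseJob whose executor is populated; `session` is an open SQLAlchemy session
reset = reset_state_for_orphaned_tasks(job, filter_by_dag_run=None, session=session)
job.log.info("Manually reset %s orphaned TaskInstances", len(reset))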