Example #1
    def clear(
        cls,
        execution_date: Optional[pendulum.DateTime] = None,
        dag_id: Optional[str] = None,
        task_id: Optional[str] = None,
        session: Session = NEW_SESSION,
        *,
        run_id: Optional[str] = None,
    ) -> None:
        """:sphinx-autoapi-skip:"""
        from airflow.models import DagRun

        # Given this function's historic argument order (execution_date used to be the
        # first argument), adding a new optional param means giving every parameter a default :(
        if dag_id is None:
            raise TypeError("clear() missing required argument: dag_id")
        if task_id is None:
            raise TypeError("clear() missing required argument: task_id")

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                "Exactly one of run_id or execution_date must be passed")

        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, DeprecationWarning, stacklevel=3)
            run_id = (session.query(DagRun.run_id).filter(
                DagRun.dag_id == dag_id,
                DagRun.execution_date == execution_date).scalar())

        return session.query(cls).filter_by(dag_id=dag_id,
                                            task_id=task_id,
                                            run_id=run_id).delete()
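
The method above validates its arguments with exactly_one and then deletes the matching rows. A minimal usage sketch of the non-deprecated run_id path follows; it assumes an initialized Airflow metadata database, and the dag_id, task_id, and run_id values are hypothetical placeholders.

from airflow.models import XCom

# Delete every XCom pushed by one task in one DAG run.
# run_id is keyword-only; dag_id and task_id are required.
XCom.clear(
    dag_id="example_dag",
    task_id="extract",
    run_id="manual__2022-01-01T00:00:00+00:00",
)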
Example #2
    def get_many(
        cls,
        execution_date: Optional[datetime.datetime] = None,
        key: Optional[str] = None,
        task_ids: Optional[Union[str, Iterable[str]]] = None,
        dag_ids: Optional[Union[str, Iterable[str]]] = None,
        map_indexes: Union[int, Iterable[int], None] = None,
        include_prior_dates: bool = False,
        limit: Optional[int] = None,
        session: Session = NEW_SESSION,
        *,
        run_id: Optional[str] = None,
    ) -> Query:
        """:sphinx-autoapi-skip:"""
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )
        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, PendingDeprecationWarning, stacklevel=3)

        query = session.query(cls).join(cls.dag_run)

        if key:
            query = query.filter(cls.key == key)

        if is_container(task_ids):
            query = query.filter(cls.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(cls.task_id == task_ids)

        if is_container(dag_ids):
            query = query.filter(cls.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(cls.dag_id == dag_ids)

        if is_container(map_indexes):
            query = query.filter(cls.map_index.in_(map_indexes))
        elif map_indexes is not None:
            query = query.filter(cls.map_index == map_indexes)

        if include_prior_dates:
            if execution_date is not None:
                query = query.filter(DagRun.execution_date <= execution_date)
            else:
                dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
                query = query.filter(cls.execution_date <= dr.c.execution_date)
        elif execution_date is not None:
            query = query.filter(DagRun.execution_date == execution_date)
        else:
            query = query.filter(cls.run_id == run_id)

        query = query.order_by(DagRun.execution_date.desc(), cls.timestamp.desc())
        if limit:
            return query.limit(limit)
        return query
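
A hedged usage sketch of get_many with run_id follows; it assumes an initialized Airflow metadata database, and the DAG, task, and run identifiers are hypothetical. Note that the method returns a SQLAlchemy Query of XCom rows, so each value still needs deserializing.

from airflow.models import XCom

# Fetch the return values pushed by two tasks within a single run.
rows = XCom.get_many(
    run_id="manual__2022-01-01T00:00:00+00:00",
    dag_ids="example_dag",
    task_ids=["extract", "transform"],
    key="return_value",
)
for row in rows:
    print(row.task_id, XCom.deserialize_value(row))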
Example #3
    def get_one(
        cls,
        execution_date: Optional[datetime.datetime] = None,
        key: Optional[str] = None,
        task_id: Optional[str] = None,
        dag_id: Optional[str] = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
        *,
        run_id: Optional[str] = None,
        ti_key: Optional["TaskInstanceKey"] = None,
    ) -> Optional[Any]:
        """:sphinx-autoapi-skip:"""
        if not exactly_one(execution_date is not None, ti_key is not None,
                           run_id is not None):
            raise ValueError(
                "Exactly one of ti_key, run_id, or execution_date must be passed"
            )

        if ti_key is not None:
            query = session.query(cls).filter_by(
                dag_id=ti_key.dag_id,
                run_id=ti_key.run_id,
                task_id=ti_key.task_id,
            )
            if key:
                query = query.filter_by(key=key)
            query = query.limit(1)
        elif run_id:
            query = cls.get_many(
                run_id=run_id,
                key=key,
                task_ids=task_id,
                dag_ids=dag_id,
                include_prior_dates=include_prior_dates,
                limit=1,
                session=session,
            )
        elif execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, PendingDeprecationWarning, stacklevel=3)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                query = cls.get_many(
                    execution_date=execution_date,
                    key=key,
                    task_ids=task_id,
                    dag_ids=dag_id,
                    include_prior_dates=include_prior_dates,
                    limit=1,
                    session=session,
                )
        else:
            raise RuntimeError("Should not happen?")

        result = query.with_entities(cls.value).first()
        if result:
            return cls.deserialize_value(result)
        return None
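
Because get_one delegates to get_many with limit=1 and deserializes the single result, a typical call only needs the coordinates of one push. A minimal sketch, assuming an initialized Airflow metadata database and hypothetical identifiers:

from airflow.models import XCom

# Returns the deserialized value, or None if no matching XCom exists.
value = XCom.get_one(
    run_id="manual__2022-01-01T00:00:00+00:00",
    dag_id="example_dag",
    task_id="extract",
    key="return_value",
)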
Example #4
    def get_many(
        cls,
        execution_date: Optional[pendulum.DateTime] = None,
        key: Optional[str] = None,
        task_ids: Optional[Union[str, Iterable[str]]] = None,
        dag_ids: Optional[Union[str, Iterable[str]]] = None,
        include_prior_dates: bool = False,
        limit: Optional[int] = None,
        session: Session = NEW_SESSION,
        *,
        run_id: Optional[str] = None,
    ) -> Query:
        """:sphinx-autoapi-skip:"""
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                "Exactly one of execution_date or run_id must be passed")
        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, PendingDeprecationWarning, stacklevel=3)

        query = session.query(cls)

        if key:
            query = query.filter(cls.key == key)

        if is_container(task_ids):
            query = query.filter(cls.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(cls.task_id == task_ids)

        if is_container(dag_ids):
            query = query.filter(cls.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(cls.dag_id == dag_ids)

        if include_prior_dates:
            if execution_date is not None:
                query = query.filter(cls.execution_date <= execution_date)
            else:
                # This returns an empty query result for IN_MEMORY_DAGRUN_ID,
                # but that is impossible to implement. Sorry?
                dr = session.query(DagRun.execution_date).filter(
                    DagRun.run_id == run_id).subquery()
                query = query.filter(cls.execution_date <= dr.c.execution_date)
        elif execution_date is not None:
            query = query.filter(cls.execution_date == execution_date)
        elif run_id == IN_MEMORY_DAGRUN_ID:
            query = query.filter(cls.execution_date == _DISTANT_FUTURE)
        else:
            query = query.join(cls.dag_run).filter(DagRun.run_id == run_id)

        query = query.order_by(cls.execution_date.desc(), cls.timestamp.desc())
        if limit:
            return query.limit(limit)
        return query
Example #5
def set_dag_run_state_to_success(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
):
    """
    Set the dag run for a specific execution date or run_id and its task instances
    to success.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking (deprecated)
    :param run_id: the run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
             otherwise list of tasks that will be updated
    :raises: ValueError if dag or execution_date is invalid
    """
    if not exactly_one(execution_date, run_id):
        return []

    if not dag:
        return []

    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(
                f'DagRun with execution_date: {execution_date} not found')
        run_id = dag_run.run_id
    if not run_id:
        raise ValueError(f'Invalid dag_run_id: {run_id}')
    # Mark the dag run to success.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, DagRunState.SUCCESS, session)

    # Mark all task instances of the dag run to success.
    for task in dag.tasks:
        task.dag = dag
    return set_state(tasks=dag.tasks,
                     dag_run_id=run_id,
                     state=State.SUCCESS,
                     commit=commit,
                     session=session)
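
A minimal sketch of calling this helper through the non-deprecated run_id path. The airflow.api.common.mark_tasks module path is an assumption (older releases expose it under airflow.api.common.experimental.mark_tasks), the identifiers are hypothetical, and an initialized Airflow metadata database is required.

from airflow.api.common.mark_tasks import set_dag_run_state_to_success
from airflow.models import DagBag

dag = DagBag().get_dag("example_dag")  # hypothetical dag_id

# With commit=False this only returns the task instances that *would* be altered;
# pass commit=True to actually write the new states.
altered = set_dag_run_state_to_success(
    dag=dag,
    run_id="scheduled__2022-01-01T00:00:00+00:00",  # hypothetical run_id
    commit=False,
)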
Example #6
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: Optional[datetime.datetime] = None,
        session: Session = NEW_SESSION,
        *,
        run_id: Optional[str] = None,
    ) -> None:
        """:sphinx-autoapi-skip:"""
        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                "Exactly one of execution_date or run_id must be passed")

        if run_id == IN_MEMORY_DAGRUN_ID:
            execution_date = _DISTANT_FUTURE
        elif run_id is not None:
            from airflow.models.dagrun import DagRun

            execution_date = (session.query(DagRun.execution_date).filter(
                DagRun.dag_id == dag_id, DagRun.run_id == run_id).scalar())
        else:  # Guarantees execution_date is not None.
            message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, DeprecationWarning, stacklevel=3)

        # Remove duplicate XComs and insert a new one.
        session.query(cls).filter(
            cls.key == key,
            cls.execution_date == execution_date,
            cls.task_id == task_id,
            cls.dag_id == dag_id,
        ).delete()
        new = cast(
            Any, cls
        )(  # Work around Mypy complaining model not defining '__init__'.
            key=key,
            value=cls.serialize_value(value),
            execution_date=execution_date,
            task_id=task_id,
            dag_id=dag_id,
        )
        session.add(new)
        session.flush()
Example #7
    def get_one(
        cls,
        execution_date: Optional[pendulum.DateTime] = None,
        key: Optional[str] = None,
        task_id: Optional[Union[str, Iterable[str]]] = None,
        dag_id: Optional[Union[str, Iterable[str]]] = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
        *,
        run_id: Optional[str] = None,
    ) -> Optional[Any]:
        """:sphinx-autoapi-skip:"""
        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                "Exactly one of execution_date or run_id must be passed")

        if run_id is not None:
            query = cls.get_many(
                run_id=run_id,
                key=key,
                task_ids=task_id,
                dag_ids=dag_id,
                include_prior_dates=include_prior_dates,
                session=session,
            )
        elif execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, PendingDeprecationWarning, stacklevel=3)

            query = cls.get_many(
                execution_date=execution_date,
                key=key,
                task_ids=task_id,
                dag_ids=dag_id,
                include_prior_dates=include_prior_dates,
                session=session,
            )
        else:
            raise RuntimeError("Should not happen?")

        result = query.with_entities(cls.value).first()
        if result:
            return cls.deserialize_value(result)
        return None
Example #8
def set_state(
    *,
    tasks: Iterable[BaseOperator],
    dag_run_id: Optional[str] = None,
    execution_date: Optional[datetime] = None,
    upstream: bool = False,
    downstream: bool = False,
    future: bool = False,
    past: bool = False,
    state: TaskInstanceState = TaskInstanceState.SUCCESS,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from run_id) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will for subdag dag runs if needed).

    :param tasks: the iterable of tasks from which to work. task.task.dag needs to be set
    :param dag_run_id: the run_id of the dagrun to start looking from
    :param execution_date: the execution date from which to start looking (deprecated)
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :param session: database session
    :return: list of tasks that have been created and updated
    """
    if not tasks:
        return []

    if not exactly_one(execution_date, dag_run_id):
        raise ValueError(
            "Exactly one of dag_run_id and execution_date must be set")

    if execution_date and not timezone.is_localized(execution_date):
        raise ValueError(f"Received non-localized date {execution_date}")

    task_dags = {task.dag for task in tasks}
    if len(task_dags) > 1:
        raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
    dag = next(iter(task_dags))
    if dag is None:
        raise ValueError("Received tasks with no DAG")

    if execution_date:
        dag_run_id = dag.get_dagrun(execution_date=execution_date).run_id
    if not dag_run_id:
        raise ValueError("Received tasks with no dag_run_id")

    dag_run_ids = get_run_ids(dag, dag_run_id, future, past)

    task_ids = list(find_task_relatives(tasks, downstream, upstream))

    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
    confirmed_dates = [info.logical_date for info in confirmed_infos]

    sub_dag_run_ids = list(
        _iter_subdag_run_ids(dag, session, DagRunState(state), task_ids,
                             commit, confirmed_infos), )

    # now look for the task instances that are affected

    qry_dag = get_all_dag_task_query(dag, session, state, task_ids,
                                     confirmed_dates)

    if commit:
        tis_altered = qry_dag.with_for_update().all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session,
                                                 state, confirmed_dates)
            tis_altered += qry_sub_dag.with_for_update().all()
        for task_instance in tis_altered:
            task_instance.set_state(state)
    else:
        tis_altered = qry_dag.all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session,
                                                 state, confirmed_dates)
            tis_altered += qry_sub_dag.all()
    return tis_altered
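
All parameters of set_state are keyword-only, and the tasks passed in must belong to the same DAG. A usage sketch under those constraints, assuming an initialized Airflow metadata database and hypothetical DAG, task, and run identifiers (the mark_tasks module path may differ between Airflow versions):

from airflow.api.common.mark_tasks import set_state
from airflow.models import DagBag
from airflow.utils.state import TaskInstanceState

dag = DagBag().get_dag("example_dag")   # hypothetical dag_id
task = dag.get_task("extract")          # hypothetical task_id

# Dry run: commit=False lists the task instances that would be marked,
# here the task itself plus everything downstream of it in the same run.
affected = set_state(
    tasks=[task],
    dag_run_id="manual__2022-01-01T00:00:00+00:00",
    downstream=True,
    state=TaskInstanceState.SUCCESS,
    commit=False,
)
print([ti.task_id for ti in affected])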
Example #9
def set_dag_run_state_to_failed(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
):
    """
    Set the dag run for a specific execution date or run_id and its running task instances
    to failed.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking (deprecated)
    :param run_id: the DAG run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
             otherwise list of tasks that will be updated
    :raises: ValueError if dag or execution_date is invalid
    """
    if not exactly_one(execution_date, run_id):
        return []
    if not dag:
        return []

    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(
                f'DagRun with execution_date: {execution_date} not found')
        run_id = dag_run.run_id

    if not run_id:
        raise ValueError(f'Invalid dag_run_id: {run_id}')

    # Mark the dag run to failed.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)

    # Mark only RUNNING task instances.
    task_ids = [task.task_id for task in dag.tasks]
    tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id.in_(task_ids),
        TaskInstance.state.in_(State.running),
    )
    task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]

    tasks = []
    for task in dag.tasks:
        if task.task_id not in task_ids_of_running_tis:
            continue
        task.dag = dag
        tasks.append(task)

    return set_state(tasks=tasks,
                     dag_run_id=run_id,
                     state=State.FAILED,
                     commit=commit,
                     session=session)
Example #10
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: Optional[datetime.datetime] = None,
        session: Session = NEW_SESSION,
        *,
        run_id: Optional[str] = None,
    ) -> None:
        """:sphinx-autoapi-skip:"""
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                "Exactly one of run_id or execution_date must be passed")

        if run_id is None:
            message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, DeprecationWarning, stacklevel=3)
            try:
                dagrun_id, run_id = (session.query(
                    DagRun.id, DagRun.run_id).filter(
                        DagRun.dag_id == dag_id,
                        DagRun.execution_date == execution_date).one())
            except NoResultFound:
                raise ValueError(
                    f"DAG run not found on DAG {dag_id!r} at {execution_date}"
                ) from None
        elif run_id == IN_MEMORY_DAGRUN_ID:
            dagrun_id = -1
        else:
            dagrun_id = session.query(DagRun.id).filter_by(
                dag_id=dag_id, run_id=run_id).scalar()
            if dagrun_id is None:
                raise ValueError(
                    f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

        value = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=dagrun_id,
        )

        # Remove duplicate XComs and insert a new one.
        session.query(cls).filter(
            cls.key == key,
            cls.run_id == run_id,
            cls.task_id == task_id,
            cls.dag_id == dag_id,
        ).delete()
        new = cast(
            Any, cls
        )(  # Work around Mypy complaining model not defining '__init__'.
            dagrun_id=dagrun_id,
            key=key,
            value=value,
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
        )
        session.add(new)
        session.flush()
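
This variant resolves the DagRun id before writing, so the referenced run must already exist. A minimal sketch of pushing a value directly through the model, assuming an initialized Airflow metadata database and hypothetical identifiers:

from airflow.models import XCom

XCom.set(
    key="rows_loaded",       # hypothetical key
    value=42,
    task_id="load",          # hypothetical task_id
    dag_id="example_dag",    # hypothetical dag_id
    run_id="manual__2022-01-01T00:00:00+00:00",  # must name an existing DagRun
)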
Example #11
    def validate_form(self, data, **kwargs):
        """Validates set task instance state form"""
        if not exactly_one(data.get("execution_date"), data.get("dag_run_id")):
            raise ValidationError(
                "Exactly one of execution_date or dag_run_id must be provided")