Beispiel #1
0
def _trigger_dag(
        dag_id,
        dag_bag,
        dag_run,
        run_id,
        conf,
        execution_date,
        replace_microseconds,
):
    """Trigger a manual run of a DAG and of all of its subdags.

    :param dag_id: ID of the DAG to trigger
    :param dag_bag: DAG bag holding the DAG
    :param dag_run: DAG run model; its ``find`` method detects an existing run
    :param run_id: ID of the new run; defaults to ``manual__<execution_date ISO>``
    :param conf: run configuration, as a dict or as a JSON string
    :param execution_date: timezone-aware execution date; defaults to ``timezone.utcnow()``
    :param replace_microseconds: whether the microseconds of the date are zeroed
    :return: list of the created DAG runs, the parent DAG's run first
    :raises DagNotFound: if ``dag_id`` is not in ``dag_bag``
    :raises ValueError: if ``execution_date`` is not timezone-aware
    :raises DagRunAlreadyExists: if a run with ``run_id`` already exists for the DAG
    """
    if dag_id not in dag_bag.dags:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    dag = dag_bag.get_dag(dag_id)

    if not execution_date:
        execution_date = timezone.utcnow()

    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dr = dag_run.find(dag_id=dag_id, run_id=run_id)
    if dr:
        raise DagRunAlreadyExists("Run id {} already exists for dag id {}".format(
            run_id,
            dag_id
        ))

    run_conf = None
    if conf:
        run_conf = conf if isinstance(conf, dict) else json.loads(conf)

    # ``dag.subdags`` already includes nested subdags (it recurses into each
    # SubDagOperator's subdag), so the parent plus that flat list covers every
    # DAG exactly once. Re-expanding each subdag's own ``subdags`` would call
    # ``create_dagrun`` twice for nested subdags with the same run_id.
    triggers = []
    for dag_to_trigger in [dag] + dag.subdags:
        trigger = dag_to_trigger.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
        )
        triggers.append(trigger)
    return triggers
Beispiel #2
0
def trigger_dag(dag_id, run_id=None, conf=None, execution_date=None,
                replace_microseconds=True):
    """Trigger a manual run of the DAG ``dag_id`` from a freshly loaded DagBag.

    :param dag_id: ID of the DAG to trigger
    :param run_id: ID of the new run; defaults to ``manual__<execution_date ISO>``
    :param conf: run configuration, as a JSON string or an already-parsed dict
    :param execution_date: timezone-aware execution date; defaults to ``timezone.utcnow()``
    :param replace_microseconds: whether the microseconds of the date are zeroed
    :return: the DAG run created by ``dag.create_dagrun``
    :raises AirflowException: if the DAG is unknown or the run already exists
    :raises ValueError: if ``execution_date`` is not timezone-aware
    """
    dagbag = DagBag()

    if dag_id not in dagbag.dags:
        raise AirflowException("Dag id {} not found".format(dag_id))

    dag = dagbag.get_dag(dag_id)

    if not execution_date:
        execution_date = timezone.utcnow()

    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dr = DagRun.find(dag_id=dag_id, run_id=run_id)
    if dr:
        raise AirflowException("Run id {} already exists for dag id {}".format(
            run_id,
            dag_id
        ))

    run_conf = None
    if conf:
        # Accept a dict as well: ``json.loads`` would raise TypeError on one.
        run_conf = conf if isinstance(conf, dict) else json.loads(conf)

    trigger = dag.create_dagrun(
        run_id=run_id,
        execution_date=execution_date,
        state=State.RUNNING,
        conf=run_conf,
        external_trigger=True
    )

    return trigger
 def test_utcnow(self):
     """``timezone.utcnow()`` returns an aware datetime whose wall-clock time is UTC."""
     now = timezone.utcnow()
     self.assertTrue(timezone.is_localized(now))
     # ``assertEquals`` is a deprecated alias that Python 3.12 removed.
     self.assertEqual(now.replace(tzinfo=None), now.astimezone(UTC).replace(tzinfo=None))
 def test_is_aware(self):
     """``timezone.is_localized`` is true for an aware datetime and false for a naive one."""
     aware = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
     naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
     self.assertTrue(timezone.is_localized(aware))
     self.assertFalse(timezone.is_localized(naive))
Beispiel #5
0
def _trigger_dag(
        dag_id,  # type: str
        dag_bag,  # type: DagBag
        dag_run,  # type: DagModel
        run_id,  # type: Optional[str]
        conf,  # type: Optional[Union[dict, str]]
        execution_date,  # type: Optional[datetime]
        replace_microseconds,  # type: bool
):  # pylint: disable=too-many-arguments
    # type: (...) -> List[DagRun]
    """Triggers DAG run.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param dag_run: DAG Run model (its ``find`` method detects an existing run)
    :param run_id: ID of the dag_run; defaults to ``manual__<execution_date ISO>``
    :param conf: configuration, as a dict or a JSON string
    :param execution_date: timezone-aware date of execution; defaults to now
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of triggered dags, the parent DAG's run first
    :raises DagNotFound: if ``dag_id`` is not in ``dag_bag``
    :raises ValueError: if ``execution_date`` is not timezone-aware
    :raises DagRunAlreadyExists: if a run with ``run_id`` already exists
    """
    if dag_id not in dag_bag.dags:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    dag = dag_bag.get_dag(dag_id)

    execution_date = execution_date if execution_date else timezone.utcnow()

    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dag_run_id = dag_run.find(dag_id=dag_id, run_id=run_id)
    if dag_run_id:
        raise DagRunAlreadyExists(
            "Run id {} already exists for dag id {}".format(run_id, dag_id))

    run_conf = None
    if conf:
        run_conf = conf if isinstance(conf, dict) else json.loads(conf)

    # ``dag.subdags`` already includes nested subdags (it recurses into each
    # SubDagOperator's subdag), so the parent plus that flat list covers every
    # DAG exactly once. Re-expanding each subdag's own ``subdags`` would call
    # ``create_dagrun`` twice for nested subdags with the same run_id.
    triggers = []
    for dag_to_trigger in [dag] + dag.subdags:
        trigger = dag_to_trigger.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
        )
        triggers.append(trigger)
    return triggers
Beispiel #6
0
def set_state(tasks: Iterable[BaseOperator],
              execution_date: datetime.datetime,
              upstream: bool = False,
              downstream: bool = False,
              future: bool = False,
              past: bool = False,
              state: str = State.SUCCESS,
              commit: bool = False,
              session=None):  # pylint: disable=too-many-arguments,too-many-locals
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will as for subdag dag runs if needed).

    :param tasks: the iterable of tasks from which to work. task.dag needs to be set
    :param execution_date: the (timezone-aware) execution date from which to start looking
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :param session: database session
    :return: list of tasks that have been created and updated
    :raises ValueError: if the date is naive, or the tasks belong to several DAGs or to none
    """
    # ``tasks`` is iterated twice below (to collect DAGs and to find relatives),
    # and a one-shot iterator such as a generator is always truthy, so
    # materialize it once up front.
    tasks = list(tasks) if tasks else []
    if not tasks:
        return []

    if not timezone.is_localized(execution_date):
        raise ValueError(
            "Received non-localized date {}".format(execution_date))

    task_dags = {task.dag for task in tasks}
    if len(task_dags) > 1:
        raise ValueError(
            "Received tasks from multiple DAGs: {}".format(task_dags))
    dag = next(iter(task_dags))
    if dag is None:
        raise ValueError("Received tasks with no DAG")

    dates = get_execution_dates(dag, execution_date, future, past)

    task_ids = list(find_task_relatives(tasks, downstream, upstream))

    confirmed_dates = verify_dag_run_integrity(dag, dates)

    sub_dag_run_ids = get_subdag_runs(dag, session, state, task_ids, commit,
                                      confirmed_dates)

    # now look for the task instances that are affected

    qry_dag = get_all_dag_task_query(dag, session, state, task_ids,
                                     confirmed_dates)

    if commit:
        tis_altered = qry_dag.with_for_update().all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session,
                                                 state, confirmed_dates)
            tis_altered += qry_sub_dag.with_for_update().all()
        for task_instance in tis_altered:
            task_instance.state = state
            if state in State.finished():
                task_instance.end_date = timezone.utcnow()
                task_instance.set_duration()
    else:
        tis_altered = qry_dag.all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session,
                                                 state, confirmed_dates)
            tis_altered += qry_sub_dag.all()

    return tis_altered
Beispiel #7
0
def _trigger_dag(
    dag_id: str,
    dag_bag: DagBag,
    run_id: Optional[str] = None,
    conf: Optional[Union[dict, str]] = None,
    execution_date: Optional[datetime] = None,
    replace_microseconds: bool = True,
) -> List[Optional[DagRun]]:
    """Create queued, externally triggered runs for a DAG and each of its subdags.

    :param dag_id: ID of the DAG to run
    :param dag_bag: DAG bag the DAG and its hash are read from
    :param run_id: ID for the new runs; generated by the DAG's timetable if empty
    :param conf: run configuration, as a dict or as a JSON string
    :param execution_date: timezone-aware execution date; ``timezone.utcnow()`` if empty
    :param replace_microseconds: zero the microseconds of the execution date
    :return: the created DAG runs, the parent DAG's first
    :raises DagNotFound: if the DAG cannot be found in ``dag_bag``
    :raises ValueError: if the execution date is naive or precedes the
        ``start_date`` in the DAG's ``default_args``
    :raises DagRunAlreadyExists: if a matching run already exists
    """
    # Fetched before the membership check so a serialized DAG gets prefetched.
    dag = dag_bag.get_dag(dag_id)
    if dag is None or dag_id not in dag_bag.dags:
        raise DagNotFound(f"Dag id {dag_id} not found")

    if not execution_date:
        execution_date = timezone.utcnow()
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")
    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if dag.default_args and 'start_date' in dag.default_args:
        earliest_start = dag.default_args["start_date"]
        if earliest_start and execution_date < earliest_start:
            raise ValueError(
                f"The execution_date [{execution_date.isoformat()}] should be >= start_date "
                f"[{earliest_start.isoformat()}] from DAG's default_args"
            )

    logical_date = timezone.coerce_datetime(execution_date)
    data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
    if not run_id:
        run_id = dag.timetable.generate_run_id(
            run_type=DagRunType.MANUAL, logical_date=logical_date, data_interval=data_interval
        )

    if DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id):
        raise DagRunAlreadyExists(
            f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}"
        )

    if not conf:
        run_conf = None
    elif isinstance(conf, dict):
        run_conf = conf
    else:
        run_conf = json.loads(conf)

    return [
        target.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=DagRunState.QUEUED,
            conf=run_conf,
            external_trigger=True,
            dag_hash=dag_bag.dags_hash.get(dag_id),
            data_interval=data_interval,
        )
        for target in [dag] + dag.subdags
    ]
Beispiel #8
0
def set_state(
    *,
    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
    run_id: Optional[str] = None,
    execution_date: Optional[datetime] = None,
    upstream: bool = False,
    downstream: bool = False,
    future: bool = False,
    past: bool = False,
    state: TaskInstanceState = TaskInstanceState.SUCCESS,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from run_id) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will as for subdag dag runs if needed).

    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
        Each task's ``dag`` needs to be set
    :param run_id: the run_id of the dagrun to start looking from
    :param execution_date: the execution date from which to start looking (deprecated)
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :param session: database session; also used to look up the dag run for
        ``execution_date`` and to write the new task instance states
    :return: list of tasks that have been created and updated
    :raises ValueError: if not exactly one of ``run_id``/``execution_date`` is
        given, the date is naive, the tasks belong to several DAGs or to none,
        or no dag run matches ``execution_date``
    """
    if not tasks:
        return []

    if not exactly_one(execution_date, run_id):
        raise ValueError(
            "Exactly one of dag_run_id and execution_date must be set")

    if execution_date and not timezone.is_localized(execution_date):
        raise ValueError(f"Received non-localized date {execution_date}")

    task_dags = {
        task[0].dag if isinstance(task, tuple) else task.dag
        for task in tasks
    }
    if len(task_dags) > 1:
        raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
    dag = next(iter(task_dags))
    if dag is None:
        raise ValueError("Received tasks with no DAG")

    if execution_date:
        # Look the run up in the caller's session. If no run matches the date,
        # fail with a clear error instead of an AttributeError on ``None``.
        dag_run = dag.get_dagrun(execution_date=execution_date, session=session)
        if dag_run is None:
            raise ValueError(
                f"No dag run found for DAG {dag.dag_id} at {execution_date}")
        run_id = dag_run.run_id
    if not run_id:
        raise ValueError("Received tasks with no run_id")

    dag_run_ids = get_run_ids(dag, run_id, future, past)
    task_id_map_index_list = list(
        find_task_relatives(tasks, downstream, upstream))
    task_ids = [task_id for task_id, _ in task_id_map_index_list]
    # A map_index of None means no index was supplied for that task. In that case
    # every task is matched on its task_id alone.
    if any(map_index is None for _, map_index in task_id_map_index_list):
        task_id_map_index_list = list(task_ids)

    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
    confirmed_dates = [info.logical_date for info in confirmed_infos]

    sub_dag_run_ids = list(
        _iter_subdag_run_ids(dag, session, DagRunState(state), task_ids,
                             commit, confirmed_infos), )

    # now look for the task instances that are affected

    qry_dag = get_all_dag_task_query(dag, session, state,
                                     task_id_map_index_list, confirmed_dates)

    if commit:
        tis_altered = qry_dag.with_for_update().all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session,
                                                 state, confirmed_dates)
            tis_altered += qry_sub_dag.with_for_update().all()
        for task_instance in tis_altered:
            # Reuse the session that holds the FOR UPDATE row locks. Without it,
            # set_state would presumably open its own session (provide_session),
            # and that session would contend for the same locked rows.
            task_instance.set_state(state, session=session)
    else:
        tis_altered = qry_dag.all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session,
                                                 state, confirmed_dates)
            tis_altered += qry_sub_dag.all()
    return tis_altered
Beispiel #9
0
def date_range(
    start_date: datetime,
    end_date: Optional[datetime] = None,
    num: Optional[int] = None,
    delta: Optional[Union[str, timedelta, relativedelta]] = None,
) -> List[datetime]:
    """
    Get a set of dates as a list based on a start, end and delta, delta
    can be something that can be added to `datetime.datetime`
    or a cron expression as a `str`

    .. code-block:: python

        date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=timedelta(1))
            [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0),
            datetime.datetime(2016, 1, 3, 0, 0)]
        date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta='0 0 * * *')
            [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0),
            datetime.datetime(2016, 1, 3, 0, 0)]
        date_range(datetime(2016, 1, 1), datetime(2016, 3, 3), delta="0 0 1 * *")
            [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 2, 1, 0, 0),
            datetime.datetime(2016, 3, 1, 0, 0)]

    The reprs above omit ``tzinfo``. Naive dates are passed through
    ``timezone.make_aware`` (with ``start_date``'s tzinfo) before being returned.

    :param start_date: anchor date to start the series from
    :param end_date: right boundary for the date range (inclusive). When neither
        ``end_date`` nor ``num`` is given it defaults to ``timezone.utcnow()``
    :param num: alternatively to end_date, you can specify the number of
        entries you want in the range. This number can be negative (the series
        then steps backwards from ``start_date``); output will always be sorted
        regardless
    :param delta: step length: a ``datetime.timedelta``, a
        ``dateutil.relativedelta``, or a cron expression (or cron preset) as a
        string. If falsy, an empty list is returned
    :return: the dates of the series, sorted ascending
    :raises ValueError: if ``start_date`` is after ``end_date``, if both
        ``end_date`` and ``num`` are given, or if ``delta`` has an unsupported type
    """
    if not delta:
        return []
    if end_date:
        if start_date > end_date:
            raise ValueError("Wait. start_date needs to be before end_date")
        if num:
            raise ValueError("Wait. Either specify end_date OR num")
    if not end_date and not num:
        end_date = timezone.utcnow()

    delta_iscron = False
    time_zone = start_date.tzinfo

    abs_delta: Union[timedelta, relativedelta]
    if isinstance(delta, str):
        delta_iscron = True
        # croniter is driven with naive datetimes; results are re-localized below.
        if timezone.is_localized(start_date):
            start_date = timezone.make_naive(start_date, time_zone)
        cron = croniter(cron_presets.get(delta, delta), start_date)
    elif isinstance(delta, (timedelta, relativedelta)):
        abs_delta = abs(delta)
    else:
        raise ValueError("Wait. delta must be either datetime.timedelta or cron expression as str")

    dates = []
    if end_date:
        # Keep both bounds in the same (naive) form so they can be compared.
        if timezone.is_naive(start_date) and not timezone.is_naive(end_date):
            end_date = timezone.make_naive(end_date, time_zone)
        while start_date <= end_date:  # type: ignore
            if timezone.is_naive(start_date):
                dates.append(timezone.make_aware(start_date, time_zone))
            else:
                dates.append(start_date)

            if delta_iscron:
                start_date = cron.get_next(datetime)
            else:
                start_date += abs_delta
    else:
        num_entries: int = num  # type: ignore
        for _ in range(abs(num_entries)):
            if timezone.is_naive(start_date):
                dates.append(timezone.make_aware(start_date, time_zone))
            else:
                dates.append(start_date)

            # A negative ``num`` walks backwards from start_date.
            if delta_iscron and num_entries > 0:
                start_date = cron.get_next(datetime)
            elif delta_iscron:
                start_date = cron.get_prev(datetime)
            elif num_entries > 0:
                start_date += abs_delta
            else:
                start_date -= abs_delta

    return sorted(dates)
def _trigger_dag(
    dag_id: str,
    dag_bag: DagBag,
    run_id: Optional[str] = None,
    conf: Optional[Union[dict, str]] = None,
    execution_date: Optional[datetime] = None,
    replace_microseconds: bool = True,
) -> List[DagRun]:
    """Start a manual, externally triggered run of a DAG and of its subdags.

    :param dag_id: ID of the DAG to run
    :param dag_bag: DAG bag the DAG and its hash are read from
    :param run_id: ID for the new runs; generated from the execution date if empty
    :param conf: run configuration, as a dict or as a JSON string
    :param execution_date: timezone-aware execution date; ``timezone.utcnow()`` if empty
    :param replace_microseconds: zero the microseconds of the execution date
    :return: the created DAG runs, the parent DAG's first
    :raises DagNotFound: if the DAG is not in ``dag_bag``
    :raises ValueError: if the execution date is naive or precedes the
        ``start_date`` in the DAG's ``default_args``
    :raises DagRunAlreadyExists: if the DAG already has a run with ``run_id``
    """
    # Fetched before the membership check so a serialized DAG gets prefetched.
    dag = dag_bag.get_dag(dag_id)
    if dag_id not in dag_bag.dags:
        raise DagNotFound(f"Dag id {dag_id} not found")

    if not execution_date:
        execution_date = timezone.utcnow()
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")
    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if dag.default_args and 'start_date' in dag.default_args:
        earliest_start = dag.default_args["start_date"]
        if earliest_start and execution_date < earliest_start:
            raise ValueError(
                "The execution_date [{}] should be >= start_date [{}] from DAG's default_args"
                .format(execution_date.isoformat(), earliest_start.isoformat()))

    if not run_id:
        run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
    if DagRun.find(dag_id=dag_id, run_id=run_id):
        raise DagRunAlreadyExists(
            f"Run id {run_id} already exists for dag id {dag_id}")

    if not conf:
        run_conf = None
    elif isinstance(conf, dict):
        run_conf = conf
    else:
        run_conf = json.loads(conf)

    return [
        target.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
            dag_hash=dag_bag.dags_hash.get(dag_id),
        )
        for target in [dag] + dag.subdags
    ]
Beispiel #11
0
 def test_utcnow(self):
     """The result of ``timezone.utcnow()`` is aware and its wall-clock time is UTC."""
     current = timezone.utcnow()
     assert timezone.is_localized(current)
     wall_clock = current.replace(tzinfo=None)
     utc_wall_clock = current.astimezone(UTC).replace(tzinfo=None)
     assert wall_clock == utc_wall_clock
Beispiel #12
0
 def test_is_aware(self):
     """An aware datetime is localized; the same wall-clock time without tzinfo is not."""
     aware = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
     naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
     assert timezone.is_localized(aware)
     assert not timezone.is_localized(naive)
Beispiel #13
0
def _trigger_dag(
        dag_id,  # type: str
        dag_bag,  # type: DagBag
        dag_run,  # type: DagModel
        run_id,  # type: Optional[str]
        conf,  # type: Optional[Union[dict, str]]
        execution_date,  # type: Optional[datetime]
        replace_microseconds,  # type: bool
):  # pylint: disable=too-many-arguments
    # type: (...) -> List[DagRun]
    """Triggers DAG run.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param dag_run: DAG Run model (its ``find`` method detects an existing run)
    :param run_id: ID of the dag_run; defaults to ``manual__<execution_date ISO>``
    :param conf: configuration, as a dict or a JSON string
    :param execution_date: timezone-aware date of execution; defaults to now
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of triggered dags, the parent DAG's run first
    :raises DagNotFound: if ``dag_id`` is not in ``dag_bag``
    :raises ValueError: if ``execution_date`` is naive or precedes the
        ``start_date`` in the DAG's ``default_args``
    :raises DagRunAlreadyExists: if a run with ``run_id`` already exists
    """
    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized

    if dag_id not in dag_bag.dags:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    execution_date = execution_date if execution_date else timezone.utcnow()

    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if dag.default_args and 'start_date' in dag.default_args:
        min_dag_start_date = dag.default_args["start_date"]
        if min_dag_start_date and execution_date < min_dag_start_date:
            raise ValueError(
                "The execution_date [{0}] should be >= start_date [{1}] from DAG's default_args".format(
                    execution_date.isoformat(),
                    min_dag_start_date.isoformat()))

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dag_run_id = dag_run.find(dag_id=dag_id, run_id=run_id)
    if dag_run_id:
        raise DagRunAlreadyExists("Run id {} already exists for dag id {}".format(
            run_id,
            dag_id
        ))

    run_conf = None
    if conf:
        run_conf = conf if isinstance(conf, dict) else json.loads(conf)

    triggers = []
    dags_to_trigger = [dag] + dag.subdags
    for _dag in dags_to_trigger:
        trigger = _dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
        )
        triggers.append(trigger)
    return triggers
Beispiel #14
0
 def test_utcnow(self):
     """``timezone.utcnow()`` returns an aware datetime whose wall-clock time is UTC."""
     now = timezone.utcnow()
     self.assertTrue(timezone.is_localized(now))
     # ``assertEquals`` is a deprecated alias that Python 3.12 removed.
     self.assertEqual(now.replace(tzinfo=None),
                      now.astimezone(UTC).replace(tzinfo=None))
Beispiel #15
0
 def test_is_aware(self):
     """``timezone.is_localized`` accepts aware datetimes and rejects naive ones."""
     moment = (2011, 9, 1, 13, 20, 30)
     self.assertTrue(timezone.is_localized(datetime.datetime(*moment, tzinfo=EAT)))
     self.assertFalse(timezone.is_localized(datetime.datetime(*moment)))
Beispiel #16
0
def set_state(task,
              execution_date,
              upstream=False,
              downstream=False,
              future=False,
              past=False,
              state=State.SUCCESS,
              commit=False):
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will as for subdag dag runs if needed).

    :param task: the task from which to work. task.dag needs to be set
    :param execution_date: the (timezone-aware) execution date from which to start looking
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :return: list of tasks that have been created and updated
    """
    assert timezone.is_localized(execution_date)

    # microseconds are supported by the database, but is not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    assert task.dag is not None
    dag = task.dag

    latest_execution_date = dag.latest_execution_date
    assert latest_execution_date is not None

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date

    if 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date

    start_date = execution_date if not past else start_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (siblings = downstream, parents = upstream) if needed
    task_ids = [task.task_id]
    if downstream:
        relatives = task.get_flat_relatives(upstream=False)
        task_ids += [t.task_id for t in relatives]
    if upstream:
        relatives = task.get_flat_relatives(upstream=True)
        task_ids += [t.task_id for t in relatives]

    # verify the integrity of the dag runs in case a task was added or removed
    # set the confirmed execution dates as they might be different
    # from what was provided
    confirmed_dates = []
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
    for dr in drs:
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    session = Session()
    # Release the session even when a step below raises (verify_integrity,
    # a query or the commit). Previously it was only closed on success.
    try:
        # go through subdagoperators and create dag runs. We will only work
        # within the scope of the subdag. We wont propagate to the parent dag,
        # but we will propagate from parent to subdag.
        dags = [dag]
        sub_dag_ids = []
        while dags:
            current_dag = dags.pop()
            for task_id in task_ids:
                if not current_dag.has_task(task_id):
                    continue

                current_task = current_dag.get_task(task_id)
                if isinstance(current_task, SubDagOperator):
                    # this works as a kind of integrity check
                    # it creates missing dag runs for subdagoperators,
                    # maybe this should be moved to dagrun.verify_integrity
                    drs = _create_dagruns(
                        current_task.subdag,
                        execution_dates=confirmed_dates,
                        state=State.RUNNING,
                        run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                    for dr in drs:
                        dr.dag = current_task.subdag
                        dr.verify_integrity()
                        if commit:
                            dr.state = state
                            session.merge(dr)

                    dags.append(current_task.subdag)
                    sub_dag_ids.append(current_task.subdag.dag_id)

        # now look for the task instances that are affected
        TI = TaskInstance

        # get all tasks of the main dag that will be affected by a state change
        qry_dag = session.query(TI).filter(TI.dag_id == dag.dag_id,
                                           TI.execution_date.in_(confirmed_dates),
                                           TI.task_id.in_(task_ids)).filter(
                                               or_(TI.state.is_(None),
                                                   TI.state != state))

        # get *all* tasks of the sub dags
        if sub_dag_ids:
            qry_sub_dag = session.query(TI).filter(
                TI.dag_id.in_(sub_dag_ids),
                TI.execution_date.in_(confirmed_dates)).filter(
                    or_(TI.state.is_(None), TI.state != state))

        if commit:
            tis_altered = qry_dag.with_for_update().all()
            if sub_dag_ids:
                tis_altered += qry_sub_dag.with_for_update().all()
            for ti in tis_altered:
                ti.state = state
            session.commit()
        else:
            tis_altered = qry_dag.all()
            if sub_dag_ids:
                tis_altered += qry_sub_dag.all()
    finally:
        # Detach the loaded instances so they stay usable after close.
        session.expunge_all()
        session.close()

    return tis_altered
Beispiel #17
0
def set_state(task, execution_date, upstream=False, downstream=False,
              future=False, past=False, state=State.SUCCESS, commit=False):
    """
    Set the state of a task instance and, if needed, of its relatives.

    Can set state for future tasks (calculated from execution_date) and
    retroactively for past tasks. Will verify integrity of past dag runs in
    order to create tasks that did not exist. It will not create dag runs
    that are missing on the schedule (but it will create subdag dag runs if
    needed).

    :param task: the task from which to work. task.dag needs to be set
    :param execution_date: the execution date from which to start looking;
        must be timezone-aware. Microseconds are dropped.
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :return: list of task instances that have been altered (commit=True) or
        that would be altered (commit=False)
    :raises AssertionError: if execution_date is not timezone-aware, task.dag
        is not set, or the dag has no latest execution date
    """
    # Explicit raises instead of ``assert`` statements so these checks are not
    # stripped under ``python -O``; AssertionError is kept for existing callers.
    if not timezone.is_localized(execution_date):
        raise AssertionError(
            "execution_date must be timezone-aware, got {}".format(execution_date))

    # microseconds are supported by the database, but are not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    dag = task.dag
    if dag is None:
        raise AssertionError(
            "task {} is not assigned to a dag".format(task.task_id))

    latest_execution_date = dag.latest_execution_date
    if latest_execution_date is None:
        raise AssertionError(
            "dag {} has no latest execution date".format(dag.dag_id))

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date
    if not past:
        start_date = execution_date
    elif 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (siblings = downstream, parents = upstream) if needed
    task_ids = [task.task_id]
    if downstream:
        task_ids += [t.task_id for t in task.get_flat_relatives(upstream=False)]
    if upstream:
        task_ids += [t.task_id for t in task.get_flat_relatives(upstream=True)]

    # verify the integrity of the dag runs in case a task was added or removed
    # set the confirmed execution dates as they might be different
    # from what was provided
    confirmed_dates = []
    for dr in DagRun.find(dag_id=dag.dag_id, execution_date=dates):
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    TI = TaskInstance
    session = Session()
    try:
        # go through subdagoperators and create dag runs. We will only work
        # within the scope of the subdag. We won't propagate to the parent dag,
        # but we will propagate from parent to subdag.
        dags = [dag]
        sub_dag_ids = []
        while dags:
            current_dag = dags.pop()
            for task_id in task_ids:
                if not current_dag.has_task(task_id):
                    continue

                current_task = current_dag.get_task(task_id)
                if not isinstance(current_task, SubDagOperator):
                    continue

                # this works as a kind of integrity check
                # it creates missing dag runs for subdagoperators,
                # maybe this should be moved to dagrun.verify_integrity
                sub_drs = _create_dagruns(current_task.subdag,
                                          execution_dates=confirmed_dates,
                                          state=State.RUNNING,
                                          run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                for dr in sub_drs:
                    dr.dag = current_task.subdag
                    dr.verify_integrity()
                    if commit:
                        dr.state = state
                        session.merge(dr)

                dags.append(current_task.subdag)
                sub_dag_ids.append(current_task.subdag.dag_id)

        # task instances whose state differs from the target state (a NULL
        # state must be matched explicitly, since NULL != x is not true in SQL)
        needs_change = or_(TI.state.is_(None), TI.state != state)

        # the selected tasks of the main dag ...
        queries = [
            session.query(TI).filter(
                TI.dag_id == dag.dag_id,
                TI.execution_date.in_(confirmed_dates),
                TI.task_id.in_(task_ids),
            ).filter(needs_change)
        ]
        # ... plus *all* tasks of the sub dags
        if sub_dag_ids:
            queries.append(
                session.query(TI).filter(
                    TI.dag_id.in_(sub_dag_ids),
                    TI.execution_date.in_(confirmed_dates),
                ).filter(needs_change)
            )

        tis_altered = []
        for qry in queries:
            # lock the rows we are about to modify
            if commit:
                qry = qry.with_for_update()
            tis_altered += qry.all()

        if commit:
            for ti in tis_altered:
                ti.state = state
            session.commit()
    finally:
        # Always release the session (and any row locks it holds), including
        # when an exception is raised above; detach the returned objects so
        # they stay usable after close.
        session.expunge_all()
        session.close()

    return tis_altered
Beispiel #18
0
def set_dag_run_state_to_failed(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Set the dag run for a specific execution date or run_id and its
    unfinished task instances to failed/skipped.

    Task instances in a running state are handed to ``set_state`` with
    ``State.FAILED``. Task instances that are neither running nor finished,
    including those with no state yet, are set to ``State.SKIPPED``.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking (deprecated)
    :param run_id: the DAG run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
             otherwise list of tasks that will be updated. An empty list is
             returned when not exactly one of execution_date/run_id is given
             or when dag is falsy.
    :raises ValueError: if execution_date is not timezone-aware, no DagRun
        exists for it, or the resolved run_id is empty
    """
    # Local import: ``and_`` is needed for the NULL-aware filter below and may
    # not be imported at module level.
    from sqlalchemy import and_, or_

    if not exactly_one(execution_date, run_id):
        return []
    if not dag:
        return []

    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(f'DagRun with execution_date: {execution_date} not found')
        run_id = dag_run.run_id

    if not run_id:
        raise ValueError(f'Invalid dag_run_id: {run_id}')

    # Mark the dag run to failed.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)

    # Find the tasks whose instances are currently running; only those are
    # passed to set_state below to be marked FAILED.
    task_ids = [task.task_id for task in dag.tasks]
    running_tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id.in_(task_ids),
        TaskInstance.state.in_(State.running),
    )
    # a set gives O(1) membership checks in the loop below
    running_task_ids = {ti.task_id for ti in running_tis}

    tasks = []
    for task in dag.tasks:
        if task.task_id not in running_task_ids:
            continue
        task.dag = dag
        tasks.append(task)

    # Mark non-finished, non-running tasks as SKIPPED. A NULL state has to be
    # matched explicitly: in SQL ``NULL NOT IN (...)`` evaluates to NULL rather
    # than true, so task instances that were never scheduled would be missed.
    tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.run_id == run_id,
        or_(
            TaskInstance.state.is_(None),
            and_(
                TaskInstance.state.not_in(State.finished),
                TaskInstance.state.not_in(State.running),
            ),
        ),
    ).all()
    if commit:
        for ti in tis:
            # keep the change in the caller's session/transaction, which is
            # the session these task instances were loaded from
            ti.set_state(State.SKIPPED, session=session)

    return tis + set_state(tasks=tasks,
                           run_id=run_id,
                           state=State.FAILED,
                           commit=commit,
                           session=session)