def _trigger_dag(
        dag_id,
        dag_bag,
        dag_run,
        run_id,
        conf,
        execution_date,
        replace_microseconds,
):
    """Trigger a DAG run for ``dag_id`` and for each of its subdags.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param dag_run: DAG Run model (used to look up duplicate runs)
    :param run_id: ID of the dag_run; autogenerated from execution_date if absent
    :param conf: run configuration, either a dict or a JSON string
    :param execution_date: date of execution; defaults to now (UTC)
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of created DagRun objects (one per dag/subdag)
    :raises DagNotFound: if ``dag_id`` is not known to ``dag_bag``
    :raises DagRunAlreadyExists: if a run with the same run_id exists
    """
    if dag_id not in dag_bag.dags:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    dag = dag_bag.get_dag(dag_id)

    if not execution_date:
        execution_date = timezone.utcnow()

    # An `assert` is stripped when Python runs with -O, silently skipping the
    # check; validate explicitly instead.
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dr = dag_run.find(dag_id=dag_id, run_id=run_id)
    if dr:
        raise DagRunAlreadyExists("Run id {} already exists for dag id {}".format(
            run_id, dag_id
        ))

    run_conf = None
    if conf:
        # isinstance (rather than `type(conf) is dict`) so dict subclasses
        # are also accepted without being re-parsed as JSON.
        if isinstance(conf, dict):
            run_conf = conf
        else:
            run_conf = json.loads(conf)

    triggers = []
    dags_to_trigger = [dag]
    # Walk the dag plus all (transitively nested) subdags, creating a run for each.
    while dags_to_trigger:
        dag = dags_to_trigger.pop()
        trigger = dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
        )
        triggers.append(trigger)
        if dag.subdags:
            dags_to_trigger.extend(dag.subdags)
    return triggers
def trigger_dag(dag_id, run_id=None, conf=None, execution_date=None,
                replace_microseconds=True):
    """Trigger a manual DAG run for ``dag_id``.

    :param dag_id: ID of the DAG to trigger
    :param run_id: optional run id; autogenerated from the execution date if absent
    :param conf: run configuration, either a dict or a JSON string
    :param execution_date: optional execution date; defaults to now (UTC)
    :param replace_microseconds: whether microseconds should be zeroed
    :return: the created DagRun
    :raises AirflowException: if the DAG is unknown or the run id already exists
    """
    dagbag = DagBag()

    if dag_id not in dagbag.dags:
        raise AirflowException("Dag id {} not found".format(dag_id))

    dag = dagbag.get_dag(dag_id)

    if not execution_date:
        execution_date = timezone.utcnow()

    # `assert` is stripped under `python -O`; validate explicitly instead.
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dr = DagRun.find(dag_id=dag_id, run_id=run_id)
    if dr:
        raise AirflowException("Run id {} already exists for dag id {}".format(
            run_id, dag_id
        ))

    run_conf = None
    if conf:
        # Accept a pre-parsed dict as well as a JSON string; the previous
        # unconditional json.loads(conf) raised TypeError for dict input.
        run_conf = conf if isinstance(conf, dict) else json.loads(conf)

    trigger = dag.create_dagrun(
        run_id=run_id,
        execution_date=execution_date,
        state=State.RUNNING,
        conf=run_conf,
        external_trigger=True
    )

    return trigger
def test_utcnow(self):
    """utcnow() must return an aware datetime whose UTC wall time equals itself."""
    now = timezone.utcnow()
    self.assertTrue(timezone.is_localized(now))
    # assertEquals is a deprecated alias removed in Python 3.12; use assertEqual.
    self.assertEqual(now.replace(tzinfo=None),
                     now.astimezone(UTC).replace(tzinfo=None))
def test_is_aware(self):
    """is_localized() distinguishes tz-aware from naive datetimes."""
    aware = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
    naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
    self.assertTrue(timezone.is_localized(aware))
    self.assertFalse(timezone.is_localized(naive))
def _trigger_dag(
        dag_id,  # type: str
        dag_bag,  # type: DagBag
        dag_run,  # type: DagModel
        run_id,  # type: Optional[str]
        conf,  # type: Optional[Union[dict, str]]
        execution_date,  # type: Optional[datetime]
        replace_microseconds,  # type: bool
):  # pylint: disable=too-many-arguments
    # type: (...) -> List[DagRun]
    """Triggers DAG run.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param dag_run: DAG Run model
    :param run_id: ID of the dag_run
    :param conf: configuration
    :param execution_date: date of execution
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of triggered dags
    :raises DagNotFound: if ``dag_id`` is not known to ``dag_bag``
    :raises DagRunAlreadyExists: if a run with the same run_id exists
    """
    if dag_id not in dag_bag.dags:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    dag = dag_bag.get_dag(dag_id)

    execution_date = execution_date if execution_date else timezone.utcnow()

    # `assert` is stripped under `python -O`, silently skipping this check;
    # raise explicitly instead.
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dag_run_id = dag_run.find(dag_id=dag_id, run_id=run_id)
    if dag_run_id:
        raise DagRunAlreadyExists(
            "Run id {} already exists for dag id {}".format(run_id, dag_id))

    run_conf = None
    if conf:
        if isinstance(conf, dict):
            run_conf = conf
        else:
            run_conf = json.loads(conf)

    # Plain list literals instead of list() calls.
    triggers = []
    dags_to_trigger = [dag]
    # Create a run for the dag itself and every (transitively nested) subdag.
    while dags_to_trigger:
        dag = dags_to_trigger.pop()
        trigger = dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
        )
        triggers.append(trigger)
        if dag.subdags:
            dags_to_trigger.extend(dag.subdags)
    return triggers
def set_state(tasks: Iterable[BaseOperator],
              execution_date: datetime.datetime,
              upstream: bool = False,
              downstream: bool = False,
              future: bool = False,
              past: bool = False,
              state: str = State.SUCCESS,
              commit: bool = False,
              session=None):  # pylint: disable=too-many-arguments,too-many-locals
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will as for subdag dag runs if needed).

    :param tasks: the iterable of tasks from which to work. task.task.dag needs to be set
    :param execution_date: the execution date from which to start looking
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :param session: database session
    :return: list of tasks that have been created and updated
    """
    if not tasks:
        return []

    if not timezone.is_localized(execution_date):
        raise ValueError(
            "Received non-localized date {}".format(execution_date))

    # All tasks must belong to a single, non-None DAG.
    task_dags = {task.dag for task in tasks}
    if len(task_dags) > 1:
        raise ValueError(
            "Received tasks from multiple DAGs: {}".format(task_dags))
    dag = next(iter(task_dags))
    if dag is None:
        raise ValueError("Received tasks with no DAG")

    # Execution dates to consider, expanded by the future/past flags.
    dates = get_execution_dates(dag, execution_date, future, past)

    # Expand the task set with upstream/downstream relatives if requested.
    task_ids = list(find_task_relatives(tasks, downstream, upstream))

    # Only dates for which a dag run actually exists are kept.
    confirmed_dates = verify_dag_run_integrity(dag, dates)

    sub_dag_run_ids = get_subdag_runs(dag, session, state, task_ids, commit,
                                      confirmed_dates)

    # now look for the task instances that are affected
    qry_dag = get_all_dag_task_query(dag, session, state, task_ids,
                                     confirmed_dates)

    if commit:
        # Lock rows (SELECT ... FOR UPDATE) before mutating their state.
        tis_altered = qry_dag.with_for_update().all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session,
                                                 state, confirmed_dates)
            tis_altered += qry_sub_dag.with_for_update().all()
        for task_instance in tis_altered:
            task_instance.state = state
            if state in State.finished():
                # Finished states also get an end date and recomputed duration.
                task_instance.end_date = timezone.utcnow()
                task_instance.set_duration()
    else:
        # Dry run: only report which task instances would be altered.
        tis_altered = qry_dag.all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session,
                                                 state, confirmed_dates)
            tis_altered += qry_sub_dag.all()
    return tis_altered
def _trigger_dag(
    dag_id: str,
    dag_bag: DagBag,
    run_id: Optional[str] = None,
    conf: Optional[Union[dict, str]] = None,
    execution_date: Optional[datetime] = None,
    replace_microseconds: bool = True,
) -> List[Optional[DagRun]]:
    """Triggers DAG run.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param run_id: ID of the dag_run
    :param conf: configuration (dict or JSON string)
    :param execution_date: date of execution
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of triggered dags (one run per dag/subdag)
    :raises DagNotFound: if ``dag_id`` is not known to ``dag_bag``
    :raises DagRunAlreadyExists: if a duplicate run is found
    :raises ValueError: for naive dates or dates before the DAG's start_date
    """
    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized

    if dag is None or dag_id not in dag_bag.dags:
        raise DagNotFound(f"Dag id {dag_id} not found")

    execution_date = execution_date if execution_date else timezone.utcnow()

    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    # Reject execution dates earlier than the DAG's declared start_date.
    if dag.default_args and 'start_date' in dag.default_args:
        min_dag_start_date = dag.default_args["start_date"]
        if min_dag_start_date and execution_date < min_dag_start_date:
            raise ValueError(
                f"The execution_date [{execution_date.isoformat()}] should be >= start_date "
                f"[{min_dag_start_date.isoformat()}] from DAG's default_args"
            )
    logical_date = timezone.coerce_datetime(execution_date)

    # The DAG's timetable decides both the data interval and (if no run_id was
    # supplied) the generated run id for this manual trigger.
    data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
    run_id = run_id or dag.timetable.generate_run_id(
        run_type=DagRunType.MANUAL, logical_date=logical_date, data_interval=data_interval
    )
    dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id)

    if dag_run:
        raise DagRunAlreadyExists(
            f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}"
        )

    run_conf = None
    if conf:
        # Accept an already-parsed dict or a JSON string.
        run_conf = conf if isinstance(conf, dict) else json.loads(conf)

    dag_runs = []
    # One run for the dag itself plus one per subdag.
    dags_to_run = [dag] + dag.subdags
    for _dag in dags_to_run:
        dag_run = _dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            # Runs start QUEUED here; the scheduler promotes them to RUNNING.
            state=DagRunState.QUEUED,
            conf=run_conf,
            external_trigger=True,
            dag_hash=dag_bag.dags_hash.get(dag_id),
            data_interval=data_interval,
        )
        dag_runs.append(dag_run)

    return dag_runs
def set_state(
    *,
    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
    run_id: Optional[str] = None,
    execution_date: Optional[datetime] = None,
    upstream: bool = False,
    downstream: bool = False,
    future: bool = False,
    past: bool = False,
    state: TaskInstanceState = TaskInstanceState.SUCCESS,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from run_id) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will as for subdag dag runs if needed).

    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
        task.task.dag needs to be set
    :param run_id: the run_id of the dagrun to start looking from
    :param execution_date: the execution date from which to start looking (deprecated)
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :param session: database session
    :return: list of tasks that have been created and updated
    """
    if not tasks:
        return []

    # The caller must identify the dagrun either by run_id or (deprecated)
    # by execution_date — never both, never neither.
    if not exactly_one(execution_date, run_id):
        raise ValueError(
            "Exactly one of dag_run_id and execution_date must be set")

    if execution_date and not timezone.is_localized(execution_date):
        raise ValueError(f"Received non-localized date {execution_date}")

    # Tasks may arrive as operators or (operator, map_index) tuples; all must
    # belong to a single, non-None DAG.
    task_dags = {
        task[0].dag if isinstance(task, tuple) else task.dag
        for task in tasks
    }
    if len(task_dags) > 1:
        raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
    dag = next(iter(task_dags))
    if dag is None:
        raise ValueError("Received tasks with no DAG")

    # Resolve a (deprecated) execution_date into the corresponding run_id.
    if execution_date:
        run_id = dag.get_dagrun(execution_date=execution_date).run_id
    if not run_id:
        raise ValueError("Received tasks with no run_id")

    dag_run_ids = get_run_ids(dag, run_id, future, past)
    task_id_map_index_list = list(
        find_task_relatives(tasks, downstream, upstream))
    task_ids = [task_id for task_id, _ in task_id_map_index_list]
    # check if task_id_map_index_list contains map_index of None
    # if it contains None, there was no map_index supplied for the task
    for _, index in task_id_map_index_list:
        if index is None:
            # Fall back to bare task ids when no map_index was supplied.
            task_id_map_index_list = [
                task_id for task_id, _ in task_id_map_index_list
            ]
            break

    # Only runs that actually exist are considered.
    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
    confirmed_dates = [info.logical_date for info in confirmed_infos]

    sub_dag_run_ids = list(
        _iter_subdag_run_ids(dag, session, DagRunState(state), task_ids, commit,
                             confirmed_infos),
    )

    # now look for the task instances that are affected
    qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list,
                                     confirmed_dates)

    if commit:
        # Lock rows (SELECT ... FOR UPDATE) before mutating their state.
        tis_altered = qry_dag.with_for_update().all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state,
                                                 confirmed_dates)
            tis_altered += qry_sub_dag.with_for_update().all()
        for task_instance in tis_altered:
            task_instance.set_state(state)
    else:
        # Dry run: only report which task instances would be altered.
        tis_altered = qry_dag.all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state,
                                                 confirmed_dates)
            tis_altered += qry_sub_dag.all()
    return tis_altered
def date_range(
    start_date: datetime,
    end_date: Optional[datetime] = None,
    num: Optional[int] = None,
    delta: Optional[Union[str, timedelta, relativedelta]] = None,
) -> List[datetime]:
    """
    Get a set of dates as a list based on a start, end and delta, delta
    can be something that can be added to `datetime.datetime`
    or a cron expression as a `str`

    .. code-block:: python

        date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=timedelta(1))
        [datetime.datetime(2016, 1, 1, 0, 0),
         datetime.datetime(2016, 1, 2, 0, 0),
         datetime.datetime(2016, 1, 3, 0, 0)]
        date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta='0 0 * * *')
        [datetime.datetime(2016, 1, 1, 0, 0),
         datetime.datetime(2016, 1, 2, 0, 0),
         datetime.datetime(2016, 1, 3, 0, 0)]
        date_range(datetime(2016, 1, 1), datetime(2016, 3, 3), delta="0 0 0 * *")
        [datetime.datetime(2016, 1, 1, 0, 0),
         datetime.datetime(2016, 2, 1, 0, 0),
         datetime.datetime(2016, 3, 1, 0, 0)]

    :param start_date: anchor date to start the series from
    :type start_date: datetime.datetime
    :param end_date: right boundary for the date range
    :type end_date: datetime.datetime
    :param num: alternatively to end_date, you can specify the number of
        number of entries you want in the range. This number can be negative,
        output will always be sorted regardless
    :type num: int
    :param delta: step length. It can be datetime.timedelta or cron expression as string
    :type delta: datetime.timedelta or str or dateutil.relativedelta
    """
    if not delta:
        return []
    # end_date and num are mutually exclusive ways to bound the range.
    if end_date:
        if start_date > end_date:
            raise Exception("Wait. start_date needs to be before end_date")
        if num:
            raise Exception("Wait. Either specify end_date OR num")
    if not end_date and not num:
        end_date = timezone.utcnow()

    delta_iscron = False
    time_zone = start_date.tzinfo

    abs_delta: Union[timedelta, relativedelta]
    if isinstance(delta, str):
        # A string delta is a cron expression (or a named preset such as
        # '@daily' — presumably resolved via cron_presets; confirm there).
        delta_iscron = True
        if timezone.is_localized(start_date):
            # croniter works on naive datetimes; strip the tz and re-attach
            # it when appending to the result below.
            start_date = timezone.make_naive(start_date, time_zone)
        cron = croniter(cron_presets.get(delta, delta), start_date)
    elif isinstance(delta, timedelta):
        abs_delta = abs(delta)
    elif isinstance(delta, relativedelta):
        abs_delta = abs(delta)
    else:
        raise Exception("Wait. delta must be either datetime.timedelta or cron expression as str")

    dates = []
    if end_date:
        # Bounded range: step forward from start_date until end_date.
        if timezone.is_naive(start_date) and not timezone.is_naive(end_date):
            end_date = timezone.make_naive(end_date, time_zone)
        while start_date <= end_date:  # type: ignore
            if timezone.is_naive(start_date):
                dates.append(timezone.make_aware(start_date, time_zone))
            else:
                dates.append(start_date)

            if delta_iscron:
                start_date = cron.get_next(datetime)
            else:
                start_date += abs_delta
    else:
        # Count-based range: negative num walks backwards in time.
        num_entries: int = num  # type: ignore
        for _ in range(abs(num_entries)):
            if timezone.is_naive(start_date):
                dates.append(timezone.make_aware(start_date, time_zone))
            else:
                dates.append(start_date)

            if delta_iscron and num_entries > 0:
                start_date = cron.get_next(datetime)
            elif delta_iscron:
                start_date = cron.get_prev(datetime)
            elif num_entries > 0:
                start_date += abs_delta
            else:
                start_date -= abs_delta
    # Always return the dates ascending, regardless of walk direction.
    return sorted(dates)
def _trigger_dag(
    dag_id: str,
    dag_bag: DagBag,
    run_id: Optional[str] = None,
    conf: Optional[Union[dict, str]] = None,
    execution_date: Optional[datetime] = None,
    replace_microseconds: bool = True,
) -> List[DagRun]:
    """Create externally-triggered runs for a DAG and all of its subdags.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param run_id: ID of the dag_run; generated from the execution date if absent
    :param conf: run configuration (dict or JSON string)
    :param execution_date: date of execution; defaults to now (UTC)
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of the created DagRun objects
    """
    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized

    if dag_id not in dag_bag.dags:
        raise DagNotFound(f"Dag id {dag_id} not found")

    if not execution_date:
        execution_date = timezone.utcnow()

    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    # Reject execution dates earlier than the DAG's declared start_date.
    default_args = dag.default_args
    if default_args and 'start_date' in default_args:
        min_dag_start_date = default_args["start_date"]
        if min_dag_start_date and execution_date < min_dag_start_date:
            raise ValueError(
                "The execution_date [{}] should be >= start_date [{}] from DAG's default_args"
                .format(execution_date.isoformat(), min_dag_start_date.isoformat()))

    if not run_id:
        run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date)

    if DagRun.find(dag_id=dag_id, run_id=run_id):
        raise DagRunAlreadyExists(
            f"Run id {run_id} already exists for dag id {dag_id}")

    # Accept an already-parsed dict or a JSON string for the run conf.
    if not conf:
        run_conf = None
    elif isinstance(conf, dict):
        run_conf = conf
    else:
        run_conf = json.loads(conf)

    dag_hash = dag_bag.dags_hash.get(dag_id)
    # One run for the dag itself plus one for each subdag.
    triggers = [
        each_dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
            dag_hash=dag_hash,
        )
        for each_dag in [dag] + dag.subdags
    ]
    return triggers
def test_utcnow(self):
    """utcnow() returns an aware datetime equal to its own UTC wall time."""
    current = timezone.utcnow()
    assert timezone.is_localized(current)
    stripped = current.replace(tzinfo=None)
    assert stripped == current.astimezone(UTC).replace(tzinfo=None)
def test_is_aware(self):
    """is_localized() is true only for tz-aware datetimes."""
    naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
    aware = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
    assert timezone.is_localized(aware)
    assert not timezone.is_localized(naive)
def _trigger_dag(
        dag_id,  # type: str
        dag_bag,  # type: DagBag
        dag_run,  # type: DagModel
        run_id,  # type: Optional[str]
        conf,  # type: Optional[Union[dict, str]]
        execution_date,  # type: Optional[datetime]
        replace_microseconds,  # type: bool
):  # pylint: disable=too-many-arguments
    # type: (...) -> List[DagRun]
    """Triggers DAG run.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param dag_run: DAG Run model
    :param run_id: ID of the dag_run
    :param conf: configuration
    :param execution_date: date of execution
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of triggered dags
    :raises DagNotFound: if ``dag_id`` is not known to ``dag_bag``
    :raises DagRunAlreadyExists: if a run with the same run_id exists
    :raises ValueError: for naive dates or dates before the DAG's start_date
    """
    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized

    if dag_id not in dag_bag.dags:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    execution_date = execution_date if execution_date else timezone.utcnow()

    # `assert` is stripped when Python runs with -O, which would silently skip
    # this validation; raise explicitly instead.
    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    # Reject execution dates earlier than the DAG's declared start_date.
    if dag.default_args and 'start_date' in dag.default_args:
        min_dag_start_date = dag.default_args["start_date"]
        if min_dag_start_date and execution_date < min_dag_start_date:
            raise ValueError(
                "The execution_date [{0}] should be >= start_date [{1}] from DAG's default_args".format(
                    execution_date.isoformat(),
                    min_dag_start_date.isoformat()))

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dag_run_id = dag_run.find(dag_id=dag_id, run_id=run_id)
    if dag_run_id:
        raise DagRunAlreadyExists("Run id {} already exists for dag id {}".format(
            run_id, dag_id
        ))

    run_conf = None
    if conf:
        if isinstance(conf, dict):
            run_conf = conf
        else:
            run_conf = json.loads(conf)

    triggers = []
    # One run for the dag itself and one for each of its subdags.
    dags_to_trigger = [dag] + dag.subdags
    for _dag in dags_to_trigger:
        trigger = _dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
        )
        triggers.append(trigger)
    return triggers
def test_utcnow(self):
    """utcnow() must return an aware datetime whose UTC wall time equals itself."""
    now = timezone.utcnow()
    self.assertTrue(timezone.is_localized(now))
    # assertEquals is a deprecated alias removed in Python 3.12; use assertEqual.
    self.assertEqual(now.replace(tzinfo=None),
                     now.astimezone(UTC).replace(tzinfo=None))
def test_is_aware(self):
    """is_localized() must accept aware and reject naive datetimes."""
    aware = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
    self.assertTrue(timezone.is_localized(aware))
    naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
    self.assertFalse(timezone.is_localized(naive))
def set_state(task, execution_date, upstream=False, downstream=False,
              future=False, past=False, state=State.SUCCESS, commit=False):
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will as for subdag dag runs if needed).

    :param task: the task from which to work. task.task.dag needs to be set
    :param execution_date: the execution date from which to start looking
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :return: list of tasks that have been created and updated
    """
    # NOTE(review): these asserts are stripped under `python -O`; the checks
    # would then be skipped silently.
    assert timezone.is_localized(execution_date)

    # microseconds are supported by the database, but is not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    assert task.dag is not None
    dag = task.dag

    latest_execution_date = dag.latest_execution_date
    assert latest_execution_date is not None

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date

    if 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date

    # Only look backwards when `past` is requested.
    start_date = execution_date if not past else start_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (siblings = downstream, parents = upstream) if needed
    task_ids = [task.task_id]
    if downstream:
        relatives = task.get_flat_relatives(upstream=False)
        task_ids += [t.task_id for t in relatives]
    if upstream:
        relatives = task.get_flat_relatives(upstream=True)
        task_ids += [t.task_id for t in relatives]

    # verify the integrity of the dag runs in case a task was added or removed
    # set the confirmed execution dates as they might be different
    # from what was provided
    confirmed_dates = []
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
    for dr in drs:
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    # go through subdagoperators and create dag runs. We will only work
    # within the scope of the subdag. We wont propagate to the parent dag,
    # but we will propagate from parent to subdag.
    session = Session()
    dags = [dag]
    sub_dag_ids = []
    while len(dags) > 0:
        current_dag = dags.pop()
        for task_id in task_ids:
            if not current_dag.has_task(task_id):
                continue

            current_task = current_dag.get_task(task_id)
            if isinstance(current_task, SubDagOperator):
                # this works as a kind of integrity check
                # it creates missing dag runs for subdagoperators,
                # maybe this should be moved to dagrun.verify_integrity
                drs = _create_dagruns(
                    current_task.subdag,
                    execution_dates=confirmed_dates,
                    state=State.RUNNING,
                    run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                for dr in drs:
                    dr.dag = current_task.subdag
                    dr.verify_integrity()
                    if commit:
                        dr.state = state
                        session.merge(dr)

                # Recurse into the subdag so nested subdags are handled too.
                dags.append(current_task.subdag)
                sub_dag_ids.append(current_task.subdag.dag_id)

    # now look for the task instances that are affected
    TI = TaskInstance

    # get all tasks of the main dag that will be affected by a state change
    qry_dag = session.query(TI).filter(TI.dag_id == dag.dag_id,
                                       TI.execution_date.in_(confirmed_dates),
                                       TI.task_id.in_(task_ids)).filter(
        or_(TI.state.is_(None), TI.state != state))

    # get *all* tasks of the sub dags
    if len(sub_dag_ids) > 0:
        qry_sub_dag = session.query(TI).filter(
            TI.dag_id.in_(sub_dag_ids),
            TI.execution_date.in_(confirmed_dates)).filter(
            or_(TI.state.is_(None), TI.state != state))

    if commit:
        # Lock rows (SELECT ... FOR UPDATE) before mutating their state.
        tis_altered = qry_dag.with_for_update().all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.with_for_update().all()
        for ti in tis_altered:
            ti.state = state
        session.commit()
    else:
        # Dry run: only report which task instances would be altered.
        tis_altered = qry_dag.all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.all()

    session.expunge_all()
    session.close()

    return tis_altered
def set_state(task, execution_date, upstream=False, downstream=False,
              future=False, past=False, state=State.SUCCESS, commit=False):
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will as for subdag dag runs if needed).

    :param task: the task from which to work. task.task.dag needs to be set
    :param execution_date: the execution date from which to start looking
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :return: list of tasks that have been created and updated
    """
    # NOTE(review): these asserts are stripped under `python -O`; the checks
    # would then be skipped silently.
    assert timezone.is_localized(execution_date)

    # microseconds are supported by the database, but is not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    assert task.dag is not None
    dag = task.dag

    latest_execution_date = dag.latest_execution_date
    assert latest_execution_date is not None

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date

    if 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date

    # Only look backwards when `past` is requested.
    start_date = execution_date if not past else start_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (siblings = downstream, parents = upstream) if needed
    task_ids = [task.task_id]
    if downstream:
        relatives = task.get_flat_relatives(upstream=False)
        task_ids += [t.task_id for t in relatives]
    if upstream:
        relatives = task.get_flat_relatives(upstream=True)
        task_ids += [t.task_id for t in relatives]

    # verify the integrity of the dag runs in case a task was added or removed
    # set the confirmed execution dates as they might be different
    # from what was provided
    confirmed_dates = []
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
    for dr in drs:
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    # go through subdagoperators and create dag runs. We will only work
    # within the scope of the subdag. We wont propagate to the parent dag,
    # but we will propagate from parent to subdag.
    session = Session()
    dags = [dag]
    sub_dag_ids = []
    while len(dags) > 0:
        current_dag = dags.pop()
        for task_id in task_ids:
            if not current_dag.has_task(task_id):
                continue

            current_task = current_dag.get_task(task_id)
            if isinstance(current_task, SubDagOperator):
                # this works as a kind of integrity check
                # it creates missing dag runs for subdagoperators,
                # maybe this should be moved to dagrun.verify_integrity
                drs = _create_dagruns(current_task.subdag,
                                      execution_dates=confirmed_dates,
                                      state=State.RUNNING,
                                      run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                for dr in drs:
                    dr.dag = current_task.subdag
                    dr.verify_integrity()
                    if commit:
                        dr.state = state
                        session.merge(dr)

                # Recurse into the subdag so nested subdags are handled too.
                dags.append(current_task.subdag)
                sub_dag_ids.append(current_task.subdag.dag_id)

    # now look for the task instances that are affected
    TI = TaskInstance

    # get all tasks of the main dag that will be affected by a state change
    qry_dag = session.query(TI).filter(
        TI.dag_id == dag.dag_id,
        TI.execution_date.in_(confirmed_dates),
        TI.task_id.in_(task_ids)).filter(
        or_(TI.state.is_(None), TI.state != state)
    )

    # get *all* tasks of the sub dags
    if len(sub_dag_ids) > 0:
        qry_sub_dag = session.query(TI).filter(
            TI.dag_id.in_(sub_dag_ids),
            TI.execution_date.in_(confirmed_dates)).filter(
            or_(TI.state.is_(None), TI.state != state)
        )

    if commit:
        # Lock rows (SELECT ... FOR UPDATE) before mutating their state.
        tis_altered = qry_dag.with_for_update().all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.with_for_update().all()
        for ti in tis_altered:
            ti.state = state
        session.commit()
    else:
        # Dry run: only report which task instances would be altered.
        tis_altered = qry_dag.all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.all()

    session.expunge_all()
    session.close()

    return tis_altered
def set_dag_run_state_to_failed(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Set the dag run for a specific execution date or run_id and its running task
    instances to failed.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking (deprecated)
    :param run_id: the DAG run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
        otherwise list of tasks that will be updated
    :raises: AssertionError if dag or execution_date is invalid
    """
    # The dagrun must be identified by exactly one of run_id / execution_date.
    if not exactly_one(execution_date, run_id):
        return []
    if not dag:
        return []

    # Resolve a (deprecated) execution_date into the corresponding run_id.
    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(
                f'DagRun with execution_date: {execution_date} not found')
        run_id = dag_run.run_id

    if not run_id:
        raise ValueError(f'Invalid dag_run_id: {run_id}')

    # Mark the dag run to failed.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)

    # Mark only RUNNING task instances.
    task_ids = [task.task_id for task in dag.tasks]
    tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id.in_(task_ids),
        TaskInstance.state.in_(State.running),
    )
    task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]

    # Collect the task objects for the running instances; these are failed
    # via set_state() at the end so their relatives are handled consistently.
    tasks = []
    for task in dag.tasks:
        if task.task_id not in task_ids_of_running_tis:
            continue
        task.dag = dag
        tasks.append(task)

    # Mark non-finished tasks as SKIPPED.
    tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.state.not_in(State.finished),
        TaskInstance.state.not_in(State.running),
    )

    # Materialize the query before (optionally) mutating the rows.
    tis = [ti for ti in tis]
    if commit:
        for ti in tis:
            ti.set_state(State.SKIPPED)

    # Skipped instances plus the running ones failed through set_state().
    return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED,
                           commit=commit, session=session)