class TestMarkTasks(unittest.TestCase):
    """Tests for ``set_state`` run against the example DAGs from the DagBag.

    ``dag1`` is ``example_bash_operator`` and gets one RUNNING dag run per
    entry of ``self.execution_dates``; ``dag2`` is ``example_subdag_operator``
    and gets a single RUNNING dag run at its ``default_args['start_date']``.
    """

    def setUp(self):
        """Load the example DAGs, create their dag runs and open a session."""
        self.dagbag = models.DagBag(include_examples=True)
        self.dag1 = self.dagbag.dags['example_bash_operator']
        self.dag2 = self.dagbag.dags['example_subdag_operator']

        self.execution_dates = [days_ago(2), days_ago(1)]

        bash_runs = _create_dagruns(self.dag1,
                                    self.execution_dates,
                                    state=State.RUNNING,
                                    run_id_template="scheduled__{}")
        for dag_run in bash_runs:
            dag_run.dag = self.dag1
            dag_run.verify_integrity()

        subdag_runs = _create_dagruns(self.dag2,
                                      [self.dag2.default_args['start_date']],
                                      state=State.RUNNING,
                                      run_id_template="scheduled__{}")
        for dag_run in subdag_runs:
            dag_run.dag = self.dag2
            dag_run.verify_integrity()

        self.session = Session()

    def tearDown(self):
        """Clear both DAGs, then wipe all dag runs and task instances."""
        self.dag1.clear()
        self.dag2.clear()

        # belt and braces: remove whatever the clear() calls left behind
        self.session.query(models.DagRun).delete()
        self.session.query(models.TaskInstance).delete()
        self.session.commit()
        self.session.close()

    def snapshot_state(self, dag, execution_dates):
        """Return the task instances of ``dag`` on ``execution_dates``.

        The session is expunged after the query, so the returned objects are
        detached from it.
        """
        TI = models.TaskInstance
        query = self.session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates)
        )
        captured = query.all()
        self.session.expunge_all()
        return captured

    def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
        """Assert the stored task instance states of ``dag``.

        Instances whose task id is in ``task_ids`` (and whose date is in
        ``execution_dates``) must be in ``state``; any other instance must
        still have the state recorded for it in ``old_tis``, if it is
        listed there.
        """
        TI = models.TaskInstance
        current_tis = self.session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates)
        ).all()

        self.assertTrue(len(current_tis) > 0)

        for current in current_tis:
            if (current.task_id in task_ids
                    and current.execution_date in execution_dates):
                self.assertEqual(current.state, state)
                continue

            # untouched instance: compare with the snapshot entry, if any
            for previous in old_tis:
                if (previous.task_id == current.task_id
                        and previous.execution_date == current.execution_date):
                    self.assertEqual(current.state, previous.state)

    def test_mark_tasks_now(self):
        # flags selecting only the task itself on the given execution date
        only_self = dict(upstream=False, downstream=False,
                         future=False, past=False)
        first_date = self.execution_dates[0]

        # dry run (commit=False): one instance is reported, nothing is stored
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        altered = set_state(task=task, execution_date=first_date,
                            state=State.SUCCESS, commit=False, **only_self)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          None, snapshot)

        # committed: exactly one instance moves to SUCCESS
        altered = set_state(task=task, execution_date=first_date,
                            state=State.SUCCESS, commit=True, **only_self)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          State.SUCCESS, snapshot)

        # repeating the same request alters nothing
        altered = set_state(task=task, execution_date=first_date,
                            state=State.SUCCESS, commit=True, **only_self)
        self.assertEqual(len(altered), 0)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          State.SUCCESS, snapshot)

        # a different target state alters the instance again
        altered = set_state(task=task, execution_date=first_date,
                            state=State.FAILED, commit=True, **only_self)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          State.FAILED, snapshot)

        # marking another task must leave the remaining instances alone
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_0")
        altered = set_state(task=task, execution_date=first_date,
                            state=State.SUCCESS, commit=True, **only_self)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          State.SUCCESS, snapshot)

    def test_mark_downstream(self):
        # mark "runme_1" together with everything downstream of it
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        downstream_tasks = task.get_flat_relatives(upstream=False)
        task_ids = [t.task_id for t in downstream_tasks] + [task.task_id]

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False, downstream=True,
                            future=False, past=False,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 3)
        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_upstream(self):
        # mark "run_after_loop" together with everything upstream of it
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("run_after_loop")
        upstream_tasks = task.get_flat_relatives(upstream=True)
        task_ids = [t.task_id for t in upstream_tasks] + [task.task_id]

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=True, downstream=False,
                            future=False, past=False,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 4)
        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_tasks_future(self):
        # start at the earlier date with future=True: both runs are altered
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False, downstream=False,
                            future=True, past=False,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 2)
        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
                          State.SUCCESS, snapshot)

    def test_mark_tasks_past(self):
        # start at the later date with past=True: both runs are altered
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")

        altered = set_state(task=task,
                            execution_date=self.execution_dates[1],
                            upstream=False, downstream=False,
                            future=False, past=True,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 2)
        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
                          State.SUCCESS, snapshot)

    # TODO: this skipIf should be removed once a fixing solution is found later
    #       We skip it here because this test case is working with Postgres & SQLite
    #       but not with MySQL
    @unittest.skipIf('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'),
                     "Flaky with MySQL")
    def test_mark_tasks_subdag(self):
        # mark the "section-1" task of the subdag example DAG and its
        # downstream tasks
        # NOTE(review): uses self.execution_dates[0] although the dag2 run was
        # created at dag2.default_args['start_date'] -- presumably the two
        # coincide; confirm against the example DAG
        task = self.dag2.get_task("section-1")
        downstream_tasks = task.get_flat_relatives(upstream=False)
        task_ids = [t.task_id for t in downstream_tasks] + [task.task_id]

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False, downstream=True,
                            future=False, past=False,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 14)

        # no snapshot is passed: building one would mean walking the sub dag
        # tree, i.e. re-implementing the very logic under test
        self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, [])
def set_state(task, execution_date, upstream=False, downstream=False,
              future=False, past=False, state=State.SUCCESS, commit=False):
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will create subdag dag runs if needed).

    :param task: the task from which to work. task.dag needs to be set
    :param execution_date: the execution date from which to start looking.
        Must pass ``timezone.is_localized``; microseconds are dropped
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id,
        including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the
        DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database. When False the
        affected task instances are only queried and returned
    :return: list of task instances whose state differs from ``state``; with
        ``commit=True`` these have been updated, otherwise they are the ones
        that would be updated
    """
    # NOTE: these preconditions are plain asserts (so they vanish under
    # ``python -O``); they are kept as-is so callers keep seeing
    # AssertionError.
    assert timezone.is_localized(execution_date)

    # microseconds are supported by the database, but is not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    assert task.dag is not None
    dag = task.dag

    latest_execution_date = dag.latest_execution_date
    assert latest_execution_date is not None

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date

    if 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date

    # without ``past`` the range starts at the requested execution date
    start_date = execution_date if not past else start_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (siblings = downstream, parents = upstream) if needed
    task_ids = [task.task_id]
    if downstream:
        relatives = task.get_flat_relatives(upstream=False)
        task_ids += [t.task_id for t in relatives]
    if upstream:
        relatives = task.get_flat_relatives(upstream=True)
        task_ids += [t.task_id for t in relatives]

    # verify the integrity of the dag runs in case a task was added or removed
    # set the confirmed execution dates as they might be different
    # from what was provided
    confirmed_dates = []
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
    for dr in drs:
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    session = Session()
    try:
        # go through subdagoperators and create dag runs. We will only work
        # within the scope of the subdag. We wont propagate to the parent dag,
        # but we will propagate from parent to subdag.
        dags = [dag]
        sub_dag_ids = []
        while dags:
            current_dag = dags.pop()
            for task_id in task_ids:
                if not current_dag.has_task(task_id):
                    continue

                current_task = current_dag.get_task(task_id)
                if isinstance(current_task, SubDagOperator):
                    # this works as a kind of integrity check
                    # it creates missing dag runs for subdagoperators,
                    # maybe this should be moved to dagrun.verify_integrity
                    drs = _create_dagruns(
                        current_task.subdag,
                        execution_dates=confirmed_dates,
                        state=State.RUNNING,
                        run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                    for dr in drs:
                        dr.dag = current_task.subdag
                        dr.verify_integrity()
                        if commit:
                            dr.state = state
                            session.merge(dr)

                    dags.append(current_task.subdag)
                    sub_dag_ids.append(current_task.subdag.dag_id)

        # now look for the task instances that are affected
        TI = TaskInstance

        # get all tasks of the main dag that will be affected by a state
        # change, i.e. those whose state is unset or differs from ``state``
        qry_dag = session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(confirmed_dates),
            TI.task_id.in_(task_ids)).filter(
            or_(TI.state.is_(None),
                TI.state != state)
        )

        # get *all* tasks of the sub dags
        if sub_dag_ids:
            qry_sub_dag = session.query(TI).filter(
                TI.dag_id.in_(sub_dag_ids),
                TI.execution_date.in_(confirmed_dates)).filter(
                or_(TI.state.is_(None),
                    TI.state != state)
            )

        if commit:
            # lock the rows (SELECT ... FOR UPDATE) before rewriting them
            tis_altered = qry_dag.with_for_update().all()
            if sub_dag_ids:
                tis_altered += qry_sub_dag.with_for_update().all()
            for ti in tis_altered:
                ti.state = state
            session.commit()
        else:
            tis_altered = qry_dag.all()
            if sub_dag_ids:
                tis_altered += qry_sub_dag.all()

        # detach the results so they can be returned after the session closes
        session.expunge_all()
    finally:
        # always release the session, also when a query or commit raised
        session.close()

    return tis_altered
class TestMarkTasks(unittest.TestCase):
    """Tests for ``set_state`` run against example DAGs from the DagBag.

    ``dag1`` is ``test_example_bash_operator`` and gets one RUNNING dag run
    per entry of ``self.execution_dates``; ``dag2`` is
    ``example_subdag_operator`` and gets a single RUNNING dag run at its
    ``default_args['start_date']``.
    """

    def setUp(self):
        """Load the DAGs, create their dag runs and open a session."""
        self.dagbag = models.DagBag(include_examples=True)
        self.dag1 = self.dagbag.dags['test_example_bash_operator']
        self.dag2 = self.dagbag.dags['example_subdag_operator']

        self.execution_dates = [days_ago(2), days_ago(1)]

        bash_runs = _create_dagruns(self.dag1,
                                    self.execution_dates,
                                    state=State.RUNNING,
                                    run_id_template="scheduled__{}")
        for dag_run in bash_runs:
            dag_run.dag = self.dag1
            dag_run.verify_integrity()

        subdag_runs = _create_dagruns(self.dag2,
                                      [self.dag2.default_args['start_date']],
                                      state=State.RUNNING,
                                      run_id_template="scheduled__{}")
        for dag_run in subdag_runs:
            dag_run.dag = self.dag2
            dag_run.verify_integrity()

        self.session = Session()

    def snapshot_state(self, dag, execution_dates):
        """Return the task instances of ``dag`` on ``execution_dates``.

        The session is expunged after the query, so the returned objects are
        detached from it.
        """
        TI = models.TaskInstance
        query = self.session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates)
        )
        captured = query.all()
        self.session.expunge_all()
        return captured

    def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
        """Assert the stored task instance states of ``dag``.

        Instances whose task id is in ``task_ids`` (and whose date is in
        ``execution_dates``) must be in ``state``; any other instance must
        still have the state recorded for it in ``old_tis``, if it is
        listed there.
        """
        TI = models.TaskInstance
        current_tis = self.session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates)
        ).all()

        self.assertTrue(len(current_tis) > 0)

        for current in current_tis:
            if (current.task_id in task_ids
                    and current.execution_date in execution_dates):
                self.assertEqual(current.state, state)
                continue

            # untouched instance: compare with the snapshot entry, if any
            for previous in old_tis:
                if (previous.task_id == current.task_id
                        and previous.execution_date == current.execution_date):
                    self.assertEqual(current.state, previous.state)

    def test_mark_tasks_now(self):
        # flags selecting only the task itself on the given execution date
        only_self = dict(upstream=False, downstream=False,
                         future=False, past=False)
        first_date = self.execution_dates[0]

        # dry run (commit=False): one instance is reported, nothing is stored
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        altered = set_state(task=task, execution_date=first_date,
                            state=State.SUCCESS, commit=False, **only_self)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          None, snapshot)

        # committed: exactly one instance moves to SUCCESS
        altered = set_state(task=task, execution_date=first_date,
                            state=State.SUCCESS, commit=True, **only_self)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          State.SUCCESS, snapshot)

        # repeating the same request alters nothing
        altered = set_state(task=task, execution_date=first_date,
                            state=State.SUCCESS, commit=True, **only_self)
        self.assertEqual(len(altered), 0)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          State.SUCCESS, snapshot)

        # a different target state alters the instance again
        altered = set_state(task=task, execution_date=first_date,
                            state=State.FAILED, commit=True, **only_self)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          State.FAILED, snapshot)

        # marking another task must leave the remaining instances alone
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_0")
        altered = set_state(task=task, execution_date=first_date,
                            state=State.SUCCESS, commit=True, **only_self)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [first_date],
                          State.SUCCESS, snapshot)

    def test_mark_downstream(self):
        # mark "runme_1" together with everything downstream of it
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        downstream_tasks = task.get_flat_relatives(upstream=False)
        task_ids = [t.task_id for t in downstream_tasks] + [task.task_id]

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False, downstream=True,
                            future=False, past=False,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 3)
        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_upstream(self):
        # mark "run_after_loop" together with everything upstream of it
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("run_after_loop")
        upstream_tasks = task.get_flat_relatives(upstream=True)
        task_ids = [t.task_id for t in upstream_tasks] + [task.task_id]

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=True, downstream=False,
                            future=False, past=False,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 4)
        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_tasks_future(self):
        # start at the earlier date with future=True: both runs are altered
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False, downstream=False,
                            future=True, past=False,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 2)
        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
                          State.SUCCESS, snapshot)

    def test_mark_tasks_past(self):
        # start at the later date with past=True: both runs are altered
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")

        altered = set_state(task=task,
                            execution_date=self.execution_dates[1],
                            upstream=False, downstream=False,
                            future=False, past=True,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 2)
        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
                          State.SUCCESS, snapshot)

    def test_mark_tasks_subdag(self):
        # mark the "section-1" task of the subdag example DAG and its
        # downstream tasks
        # NOTE(review): uses self.execution_dates[0] although the dag2 run was
        # created at dag2.default_args['start_date'] -- presumably the two
        # coincide; confirm against the example DAG
        task = self.dag2.get_task("section-1")
        downstream_tasks = task.get_flat_relatives(upstream=False)
        task_ids = [t.task_id for t in downstream_tasks] + [task.task_id]

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False, downstream=True,
                            future=False, past=False,
                            state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 14)

        # no snapshot is passed: building one would mean walking the sub dag
        # tree, i.e. re-implementing the very logic under test
        self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, [])

    def tearDown(self):
        """Clear both DAGs, then wipe all dag runs and task instances."""
        self.dag1.clear()
        self.dag2.clear()

        # belt and braces: remove whatever the clear() calls left behind
        self.session.query(models.DagRun).delete()
        self.session.query(models.TaskInstance).delete()
        self.session.commit()
        self.session.close()