def test_sub_set_subdag(self):
    """Backfill a sub-DAG limited to the ``leave*`` tasks and check states.

    Builds a five-task DAG, creates a running dag run for it, then backfills
    only the sub-DAG selected by ``task_regex="leave*"`` (no upstream or
    downstream tasks included). Afterwards the dag run must carry the
    backfill run id, the two ``leave`` tasks must be SUCCESS and every other
    task instance must still have no state.
    """
    dag = DAG(
        'test_sub_set_subdag',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'},
    )

    with dag:
        leave_one = DummyOperator(task_id='leave1')
        leave_two = DummyOperator(task_id='leave2')
        level_one = DummyOperator(task_id='upstream_level_1')
        level_two = DummyOperator(task_id='upstream_level_2')
        level_three = DummyOperator(task_id='upstream_level_3')

        # order randomly
        leave_two.set_downstream(level_one)
        leave_one.set_downstream(level_one)
        level_two.set_downstream(level_three)
        level_one.set_downstream(level_two)

    dag.clear()
    dag_run = dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
    )

    mock_executor = MockExecutor()
    leaves_only = dag.sub_dag(
        task_regex="leave*",
        include_downstream=False,
        include_upstream=False,
    )
    backfill = BackfillJob(
        dag=leaves_only,
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE,
        executor=mock_executor,
    )
    backfill.run()

    # the run_id should have changed, so a refresh won't work
    with self.assertRaises(sqlalchemy.orm.exc.NoResultFound):
        dag_run.refresh_from_db()

    found_runs = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE)
    dag_run = found_runs[0]

    expected_run_id = BackfillJob.ID_FORMAT_PREFIX.format(
        DEFAULT_DATE.isoformat())
    self.assertEqual(expected_run_id, dag_run.run_id)

    for task_instance in dag_run.get_task_instances():
        # Only the tasks selected by the sub-DAG regex were executed.
        was_backfilled = task_instance.task_id in ('leave1', 'leave2')
        expected_state = State.SUCCESS if was_backfilled else State.NONE
        self.assertEqual(expected_state, task_instance.state)
def restart_failed_task(self):
    """Restart the failed tasks in the specified dag run.

    Reads ``dag_id`` and ``run_id`` from the current request, looks up the
    matching dag run, collects its task instances in FAILED state and clears
    them together with all of their downstream tasks for the run's
    execution date.

    Request args:
        dag_id: dag id
        run_id: the run id of dag run

    Returns:
        ``ApiResponse.not_found`` when the dag run or the dag cannot be
        found, or when the dag run has no failed task;
        ``ApiResponse.bad_request`` when ``dag_id`` is not in the dag bag;
        otherwise ``ApiResponse.success`` with ``failed_task_count`` and
        ``clear_task_count``.
    """
    import re  # function-scope import: only needed to escape task ids

    logging.info("Executing custom 'restart_failed_task' function")

    dagbag = self.get_dagbag()
    dag_id = self.get_argument(request, 'dag_id')
    run_id = self.get_argument(request, 'run_id')

    session = settings.Session()
    try:
        dag_run = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.run_id == run_id
        ).first()
        if dag_run is None:
            return ApiResponse.not_found("dag run is not found")

        if dag_id not in dagbag.dags:
            return ApiResponse.bad_request(
                "Dag id {} not found".format(dag_id))

        dag = dagbag.get_dag(dag_id)
        if dag is None:
            return ApiResponse.not_found("dag is not found")

        tis = dag_run.get_task_instances(State.FAILED)
        logging.info('task_instances: %s', tis)

        failed_task_count = len(tis)
        if failed_task_count == 0:
            return ApiResponse.not_found("dagRun don't have failed tasks")

        # Select every failed task in ONE sub-DAG built from the full DAG.
        # Narrowing the DAG task by task would drop failed tasks that are
        # not downstream of the first one, and clearing per task would
        # report only the last count. Ids are escaped so characters such as
        # '.' are matched literally.
        failed_ids_regex = r"^(?:{0})$".format(
            "|".join(re.escape(ti.task_id) for ti in tis))
        failed_sub_dag = dag.sub_dag(
            task_regex=failed_ids_regex,
            include_downstream=True,
            include_upstream=False)

        count = failed_sub_dag.clear(
            start_date=dag_run.execution_date,
            end_date=dag_run.execution_date,
        )
        logging.info('count:%s', count)

        return ApiResponse.success({
            'failed_task_count': failed_task_count,
            'clear_task_count': count
        })
    finally:
        # Release the session on every exit path, including early returns.
        session.close()