コード例 #1
0
ファイル: test_backfill_job.py プロジェクト: shivamx/airflow
    def test_sub_set_subdag(self):
        """Backfill a sub-DAG restricted to the ``leave*`` tasks.

        Builds a five-task DAG, creates a running DagRun for it, then
        backfills only the sub-DAG selected by ``task_regex="leave*"`` with
        neither upstream nor downstream expansion.  Afterwards the two
        ``leave`` task instances must be SUCCESS and every other task
        instance of the run must have no state.
        """
        dag = DAG(
            'test_sub_set_subdag',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'},
        )

        with dag:
            leave1 = DummyOperator(task_id='leave1')
            leave2 = DummyOperator(task_id='leave2')
            level1 = DummyOperator(task_id='upstream_level_1')
            level2 = DummyOperator(task_id='upstream_level_2')
            level3 = DummyOperator(task_id='upstream_level_3')
            # Dependencies are deliberately declared in a shuffled order.
            leave2.set_downstream(level1)
            leave1.set_downstream(level1)
            level2.set_downstream(level3)
            level1.set_downstream(level2)

        dag.clear()
        original_run = dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
        )

        executor = MockExecutor()
        leaves_only = dag.sub_dag(
            task_regex="leave*",
            include_downstream=False,
            include_upstream=False,
        )
        BackfillJob(
            dag=leaves_only,
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            executor=executor,
        ).run()

        # The run_id should have changed, so refreshing the run we created
        # earlier can no longer find its row.
        with self.assertRaises(sqlalchemy.orm.exc.NoResultFound):
            original_run.refresh_from_db()

        backfilled_run = DagRun.find(dag_id=dag.dag_id,
                                     execution_date=DEFAULT_DATE)[0]
        expected_run_id = BackfillJob.ID_FORMAT_PREFIX.format(
            DEFAULT_DATE.isoformat())
        self.assertEqual(expected_run_id, backfilled_run.run_id)

        leave_ids = ('leave1', 'leave2')
        for ti in backfilled_run.get_task_instances():
            # Only the tasks selected into the sub-DAG are expected to run.
            expected_state = (
                State.SUCCESS if ti.task_id in leave_ids else State.NONE)
            self.assertEqual(expected_state, ti.state)
コード例 #2
0
    def restart_failed_task(self):
        """Restart the failed tasks of the specified dag run.

        Reads ``dag_id`` and ``run_id`` from the current request, looks up
        the matching DagRun, collects its task instances whose state is
        FAILED, and clears those tasks together with all of their downstream
        tasks for the run's execution date.

        Request arguments:
            dag_id: dag id
            run_id: the run id of dag run

        Returns:
            ``ApiResponse.not_found`` when the dag run or the dag cannot be
            found, or when the run has no failed tasks;
            ``ApiResponse.bad_request`` when ``dag_id`` is not in the dagbag;
            otherwise ``ApiResponse.success`` carrying ``failed_task_count``
            and ``clear_task_count``.
        """
        # Local import so this method does not rely on a module-level
        # ``import re`` being present in the file.
        import re

        logging.info("Executing custom 'restart_failed_task' function")

        dagbag = self.get_dagbag()

        dag_id = self.get_argument(request, 'dag_id')
        run_id = self.get_argument(request, 'run_id')

        session = settings.Session()
        try:
            dag_run = session.query(DagRun).filter(
                DagRun.dag_id == dag_id,
                DagRun.run_id == run_id
            ).first()

            if dag_run is None:
                return ApiResponse.not_found("dag run is not found")

            if dag_id not in dagbag.dags:
                return ApiResponse.bad_request(
                    "Dag id {} not found".format(dag_id))

            dag = dagbag.get_dag(dag_id)

            if dag is None:
                return ApiResponse.not_found("dag is not found")

            tis = dag_run.get_task_instances(State.FAILED)
            logging.info('task_instances: %s', tis)

            failed_task_count = len(tis)
            if failed_task_count == 0:
                return ApiResponse.not_found("dagRun don't have failed tasks")

            # Select every failed task in a single sub-DAG of the ORIGINAL
            # dag.  Deriving each sub-DAG from the previous one would drop
            # failed tasks that are not downstream of the first failed task,
            # and clearing per task would report only the last count.
            # Task ids are escaped so characters such as '.' match literally.
            failed_ids = "|".join(re.escape(ti.task_id) for ti in tis)
            failed_and_downstream = dag.sub_dag(
                task_regex=r"^(?:{0})$".format(failed_ids),
                include_downstream=True,
                include_upstream=False)

            count = failed_and_downstream.clear(
                start_date=dag_run.execution_date,
                end_date=dag_run.execution_date,
            )
            logging.info('count:%s', count)

            return ApiResponse.success({
                'failed_task_count': failed_task_count,
                'clear_task_count': count
            })
        finally:
            # Release the session on every exit path, including the early
            # error responses above.
            session.close()