def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances.

    Validates the JSON request body with ``clear_task_instance_form``, selects
    the matching task instances of ``dag_id`` and, unless the form's
    ``dry_run`` flag is set, clears them. In both cases the selected task
    instances are returned as a reference collection.

    :param dag_id: ID of the DAG whose task instances are cleared.
    :param session: session passed to ``clear_task_instances`` (presumably
        injected by a session-providing decorator outside this view).
    :raises BadRequest: if the request body fails form validation.
    :raises NotFound: if the DAG bag has no DAG with ``dag_id``.
    """
    body = request.get_json()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages)) from err

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)

    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    # Materialize the selection exactly once. The same list is both cleared and
    # reported, so the response matches what was cleared. Re-running the query
    # after clearing costs a second SELECT and could see changed rows, e.g. if
    # the form filters on state.
    task_instances = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **data).all()
    if not dry_run:
        clear_task_instances(
            task_instances,
            session,
            dag=dag,
            # QUEUED is passed for the affected DagRuns only when the form asks for it.
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances)
    )
def post_clear_task_instances(dag_id: str, session=None):
    """Clear the task instances of a DAG selected by the request's form.

    The JSON body is validated with ``clear_task_instance_form``. The matching
    task instances are always collected without prompting. They are actually
    cleared only when the form's ``dry_run`` flag is false. The response lists
    the selected task instances, each joined with its DagRun's ``run_id``.

    :param dag_id: ID of the DAG to operate on.
    :param session: session handed to ``clear_task_instances``.
    :raises BadRequest: when the body fails validation.
    :raises NotFound: when the DAG bag has no DAG called ``dag_id``.
    """
    payload = request.get_json()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag id {dag_id} not found")

    should_reset_runs = form.pop('reset_dag_runs')
    is_dry_run = form.pop('dry_run')

    # dry_run=True is passed unconditionally, so dag.clear only selects the
    # task instances; otherwise it would try to confirm on the terminal.
    selected = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **form)
    if not is_dry_run:
        new_run_state = State.RUNNING if should_reset_runs else False
        clear_task_instances(selected, session, dag=dag, dag_run_state=new_run_state)

    # Attach the owning DagRun's run_id to every selected task instance.
    run_join_condition = and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)
    with_run_ids = selected.join(DR, run_join_condition).add_column(DR.run_id)

    collection = TaskInstanceReferenceCollection(task_instances=with_run_ids.all())
    return task_instance_reference_collection_schema.dump(collection)
def post_clear_task_instances(dag_id: str, session=None):
    """
    Clear the task instances of a DAG as described by the request's form.

    The JSON body is validated with ``clear_task_instance_form``. The form
    (including its ``dry_run`` flag) is forwarded to ``dag.clear`` to select
    task instances. When ``dry_run`` is false, the selection is cleared and,
    if ``reset_dag_runs`` was requested, the DAG's runs in the form's date
    window are set to RUNNING. The response lists the selected task instances
    together with their DagRun's ``run_id``.

    :param dag_id: ID of the DAG to operate on.
    :param session: session used for clearing and for the DagRun update.
    :raises BadRequest: when the body fails validation.
    :raises NotFound: when the DAG bag has no DAG called ``dag_id``.
    """
    payload = request.get_json()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag id {dag_id} not found")

    should_reset_runs = form.pop('reset_dag_runs')
    # With get_tis=True the result is used as a query below (it is .join()-ed),
    # i.e. dag.clear presumably only returns the selection here.
    selected = dag.clear(get_tis=True, **form)

    if not form["dry_run"]:
        clear_task_instances(
            selected,
            session,
            dag=dag,
            activate_dag_runs=False,  # DagRun state is handled separately just below.
        )
        if should_reset_runs:
            dag.set_dag_runs_state(
                session=session,
                start_date=form["start_date"],
                end_date=form["end_date"],
                state=State.RUNNING,
            )

    # Attach the owning DagRun's run_id to every selected task instance.
    run_join_condition = and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)
    with_run_ids = selected.join(DR, run_join_condition).add_column(DR.run_id)

    collection = TaskInstanceReferenceCollection(task_instances=with_run_ids.all())
    return task_instance_reference_collection_schema.dump(collection)
def clear(self, start_date=None, end_date=None, upstream=False, downstream=False, session=None):
    """
    Clear this task's instances, optionally including its upstream and/or
    downstream relatives and bounded by execution date.

    :param start_date: if truthy, only instances with execution_date >= it.
    :param end_date: if truthy, only instances with execution_date <= it.
    :param upstream: also clear instances of all upstream relatives.
    :param downstream: also clear instances of all downstream relatives.
    :param session: session used for the query, the clear and the commit.
    :return: the number of task instances that were cleared.
    """
    TI = TaskInstance
    query = session.query(TI).filter(TI.dag_id == self.dag_id)
    if start_date:
        query = query.filter(TI.execution_date >= start_date)
    if end_date:
        query = query.filter(TI.execution_date <= end_date)

    target_ids = [self.task_id]
    if upstream:
        target_ids.extend(rel.task_id for rel in self.get_flat_relatives(upstream=True))
    if downstream:
        target_ids.extend(rel.task_id for rel in self.get_flat_relatives(upstream=False))
    query = query.filter(TI.task_id.in_(target_ids))

    # Fetch once; the returned count is the size of the list that gets cleared.
    matched = query.all()
    cleared_count = len(matched)
    clear_task_instances(matched, session, dag=self.dag)
    session.commit()
    return cleared_count
def test_clear_skipped_downstream_task(self):
    """
    After a downstream task is skipped by ShortCircuitOperator, clearing the
    skipped task should not cause it to be executed.
    """
    dag = DAG(
        'shortcircuit_clear_skipped_downstream_task',
        default_args={
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        },
        schedule_interval=INTERVAL,
    )
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: False)
    downstream = DummyOperator(task_id='downstream', dag=dag)

    short_op >> downstream

    dag.clear()

    dr = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # The same states are expected before and after clearing: the short
    # circuit succeeds and its downstream stays skipped.
    expected_states = {
        'make_choice': State.SUCCESS,
        'downstream': State.SKIPPED,
    }

    # Compare the full task_id -> state mapping. A per-TI loop would pass
    # vacuously if no task instances came back, and the mapping also rejects
    # unexpected task ids.
    tis = dr.get_task_instances()
    assert {ti.task_id: ti.state for ti in tis} == expected_states

    # Clear downstream
    with create_session() as session:
        clear_task_instances([t for t in tis if t.task_id == "downstream"], session=session, dag=dag)

    # Run downstream again
    downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Check if the states are correct.
    assert {ti.task_id: ti.state for ti in dr.get_task_instances()} == expected_states
def test_clear_skipped_downstream_task(self):
    """
    After a downstream task is skipped by BranchPythonOperator, clearing the
    skipped task should not cause it to be executed.
    """
    branch_op = BranchPythonOperator(
        task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
    )
    branches = [self.branch_1, self.branch_2]
    branch_op >> branches
    self.dag.clear()

    dr = self.dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    for task in branches:
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # The same states are expected before and after clearing: the chosen
    # branch succeeds and the other one stays skipped.
    expected_states = {
        'make_choice': State.SUCCESS,
        'branch_1': State.SUCCESS,
        'branch_2': State.SKIPPED,
    }

    # Compare the full task_id -> state mapping. A per-TI loop would pass
    # vacuously if no task instances came back, and the mapping also rejects
    # unexpected task ids.
    tis = dr.get_task_instances()
    assert {ti.task_id: ti.state for ti in tis} == expected_states

    # Compute the branch operator's direct relatives once, not per TI.
    child_ids = branch_op.get_direct_relative_ids()
    children_tis = [ti for ti in tis if ti.task_id in child_ids]

    # Clear the children tasks.
    with create_session() as session:
        clear_task_instances(children_tis, session=session, dag=self.dag)

    # Run the cleared tasks again.
    for task in branches:
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Check if the states are correct after children tasks are cleared.
    assert {ti.task_id: ti.state for ti in dr.get_task_instances()} == expected_states
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances.

    Validates the JSON request body with ``clear_task_instance_form`` and
    selects task instances of ``dag_id`` to clear. The selection can be
    narrowed in three ways:

    * ``dag_run_id`` pins the date window to that run's logical date;
      ``include_past`` / ``include_future`` reopen the window in that direction.
    * ``task_ids`` restricts the tasks. Entries are plain IDs or tuples whose
      first item is the task ID. ``include_upstream`` / ``include_downstream``
      add the relatives of those tasks.

    Unless the form's ``dry_run`` flag is set, the selection is cleared. In
    both cases the selected task instances are returned as a reference
    collection.

    :param dag_id: ID of the DAG whose task instances are cleared.
    :param session: session used to look up the DagRun and to clear
        (presumably injected by a session-providing decorator outside this view).
    :raises BadRequest: if the request body fails form validation.
    :raises NotFound: if the DAG, or the given DagRun within it, does not exist.
    """
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages)) from err

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)

    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    dag_run_id = data.pop('dag_run_id', None)
    future = data.pop('include_future', False)
    past = data.pop('include_past', False)
    downstream = data.pop('include_downstream', False)
    upstream = data.pop('include_upstream', False)

    if dag_run_id is not None:
        dag_run: Optional[DR] = (
            session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
        )
        if dag_run is None:
            error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
            raise NotFound(error_message)
        # Restrict the window to exactly this run's logical date.
        data['start_date'] = dag_run.logical_date
        data['end_date'] = dag_run.logical_date
    # include_past / include_future drop the corresponding bound.
    if past:
        data['start_date'] = None
    if future:
        data['end_date'] = None

    task_ids = data.pop('task_ids', None)
    if task_ids is not None:
        # Entries are plain task IDs or tuples whose first item is the task ID
        # (presumably (task_id, map_index); confirm against the form schema).
        task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids]
        dag = dag.partial_subset(
            task_ids_or_regex=task_id,
            include_downstream=downstream,
            include_upstream=upstream,
        )

        if len(dag.task_dict) > 1:
            # If we had upstream/downstream etc then also include those!
            # Test membership against the requested IDs. Comparing each ID to the
            # list `task_id` with != would always be true, re-adding the requested
            # tasks (duplicates). For tuple entries it would also add the bare ID,
            # which would widen the clear to every map index.
            requested = set(task_id)
            task_ids.extend(tid for tid in dag.task_dict if tid not in requested)

    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    # Materialize the selection once, so the same list is cleared and reported
    # and the query is not executed a second time after clearing.
    task_instances = dag.clear(
        dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data
    ).all()
    if not dry_run:
        clear_task_instances(
            task_instances,
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances)
    )
def clear(self, start_date=None, end_date=None, upstream=False, downstream=False, session=None):
    """
    Clears the state of task instances associated with the task, following
    the parameters specified.

    :param start_date: if truthy, only instances with execution_date >= it.
    :param end_date: if truthy, only instances with execution_date <= it.
    :param upstream: also clear instances of all upstream relatives.
    :param downstream: also clear instances of all downstream relatives.
    :param session: session used for the query, the clear and the commit.
    :return: the number of task instances that were cleared.
    """
    TI = TaskInstance
    qry = session.query(TI).filter(TI.dag_id == self.dag_id)

    if start_date:
        qry = qry.filter(TI.execution_date >= start_date)
    if end_date:
        qry = qry.filter(TI.execution_date <= end_date)

    tasks = [self.task_id]

    if upstream:
        tasks += [t.task_id for t in self.get_flat_relatives(upstream=True)]

    if downstream:
        tasks += [t.task_id for t in self.get_flat_relatives(upstream=False)]

    qry = qry.filter(TI.task_id.in_(tasks))
    # Fetch the rows once and count them. A separate COUNT(*) followed by
    # another SELECT is two round-trips, and its result can differ from the
    # rows actually cleared if the table changes in between.
    results = qry.all()
    count = len(results)
    clear_task_instances(results, session, dag=self.dag)
    session.commit()
    return count