def post_set_task_instances_state(dag_id, session):
    """Set a state of task instances.

    Validates the JSON request body with ``set_task_instance_state_form``,
    resolves the DAG and the task, then delegates the state change to
    ``dag.set_task_instance_state``.

    :param dag_id: ID of the DAG that owns the task.
    :param session: session forwarded to ``dag.set_task_instance_state``.
    :return: ``task_instance_reference_collection_schema`` dump of the task
        instances returned by ``dag.set_task_instance_state``.
    :raises BadRequest: if the request body fails schema validation.
    :raises NotFound: if the DAG or the task cannot be found.
    """
    payload = request.get_json()
    try:
        form = set_task_instance_state_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag ID {dag_id} not found")

    task_id = form['task_id']
    if not dag.task_dict.get(task_id):
        raise NotFound(f"Task ID {task_id} not found")

    affected = dag.set_task_instance_state(
        task_id=task_id,
        execution_date=form["execution_date"],
        state=form["new_state"],
        # The form names these flags "include_<direction>"; the DAG method takes bare "<direction>".
        **{direction: form[f"include_{direction}"] for direction in ("upstream", "downstream", "future", "past")},
        commit=not form["dry_run"],  # dry_run disables committing
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=affected)
    )
def test_success(self):
    """A valid payload loads into the fully-populated form dictionary."""
    loaded = set_task_instance_state_form.load(self.current_input)
    utc = dt.timezone(dt.timedelta(0), '+0000')
    expected_result = dict(
        task_id='print_the_context',
        new_state='failed',
        execution_date=dt.datetime(2020, 1, 1, 0, 0, tzinfo=utc),
        dry_run=True,
        include_upstream=True,
        include_downstream=True,
        include_future=True,
        include_past=True,
    )
    self.assertEqual(expected_result, loaded)
def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Set a state of task instances.

    The request body is validated with ``set_task_instance_state_form``. The
    target task instance is addressed by ``execution_date`` and/or
    ``dag_run_id`` taken from the body. Its existence is checked before the
    state change is delegated to ``dag.set_task_instance_state``.

    :param dag_id: ID of the DAG that owns the task.
    :param session: session used for the existence checks and forwarded to
        ``dag.set_task_instance_state``.
    :return: ``task_instance_reference_collection_schema`` dump of the task
        instances returned by ``dag.set_task_instance_state``.
    :raises BadRequest: if the request body fails schema validation.
    :raises NotFound: if the DAG, the task, or the addressed task instance
        does not exist.
    """
    body = request.get_json()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag ID {dag_id} not found")

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)
    if not task:
        raise NotFound(f"Task ID {task_id} not found")

    execution_date = data.get('execution_date')
    run_id = data.get('dag_run_id')

    if execution_date:
        # TI rows are also keyed by map_index (see the run_id lookup below). Without
        # pinning map_index, a task with several TI rows for this date made
        # ``one_or_none()`` raise MultipleResultsFound, an unhandled 500.
        # Restricting to map_index == -1 matches the run_id branch and keeps
        # the query to at most one row.
        ti = (
            session.query(TI)
            .filter(
                TI.task_id == task_id,
                TI.dag_id == dag_id,
                TI.execution_date == execution_date,
                TI.map_index == -1,
            )
            .one_or_none()
        )
        if ti is None:
            raise NotFound(
                detail=f"Task instance not found for task {task_id!r} on execution_date {execution_date}"
            )

    if run_id and not session.query(TI).get(
        {'task_id': task_id, 'dag_id': dag_id, 'run_id': run_id, 'map_index': -1}
    ):
        error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
        raise NotFound(detail=error_message)

    tis = dag.set_task_instance_state(
        task_id=task_id,
        dag_run_id=run_id,
        execution_date=execution_date,
        state=data["new_state"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        commit=not data["dry_run"],  # dry_run disables committing
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis)
    )
def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Set a state of task instances.

    Validates the JSON body with ``set_task_instance_state_form``. It then
    confirms that a task instance exists for the requested task and
    ``execution_date`` before delegating to ``dag.set_task_instance_state``.

    :param dag_id: ID of the DAG that owns the task.
    :param session: session used for the existence check and forwarded to
        ``dag.set_task_instance_state``.
    :return: ``task_instance_reference_collection_schema`` dump of the task
        instances returned by ``dag.set_task_instance_state``.
    :raises BadRequest: if the request body fails schema validation.
    :raises NotFound: if the DAG, the task, or the task instance is missing.
    """
    payload = request.get_json()
    try:
        form = set_task_instance_state_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag ID {dag_id} not found")

    task_id = form['task_id']
    if not dag.task_dict.get(task_id):
        raise NotFound(f"Task ID {task_id} not found")

    execution_date = form['execution_date']
    ti_query = session.query(TI).filter_by(execution_date=execution_date, task_id=task_id, dag_id=dag_id)
    try:
        ti_query.one()
    except NoResultFound:
        raise NotFound(
            f"Task instance not found for task {task_id} on execution_date {execution_date}"
        )

    affected = dag.set_task_instance_state(
        task_id=task_id,
        execution_date=execution_date,
        state=form["new_state"],
        # The form names these flags "include_<direction>"; the DAG method takes bare "<direction>".
        **{direction: form[f"include_{direction}"] for direction in ("upstream", "downstream", "future", "past")},
        commit=not form["dry_run"],  # dry_run disables committing
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=affected)
    )
def post_set_task_instances_state(dag_id, session):
    """Set a state of task instances.

    Validates the JSON body with ``set_task_instance_state_form``, resolves
    the DAG and task, and applies the new state via ``set_state``. Each
    returned task instance is then paired with the ``run_id`` of the DagRun
    that shares its execution date.

    :param dag_id: ID of the DAG that owns the task.
    :param session: session used for the DagRun ``run_id`` lookup.
    :return: ``task_instance_reference_collection_schema`` dump of
        ``(task_instance, run_id)`` pairs. ``run_id`` is ``None`` when no
        DagRun of this DAG has that execution date.
    :raises BadRequest: if the request body fails schema validation.
    :raises NotFound: if the DAG (including one missing from the serialized
        DAG table) or the task cannot be found.
    """
    body = request.get_json()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    error_message = f"Dag ID {dag_id} not found"
    # Only the lookup itself can raise SerializedDagNotFound; keep the try body minimal.
    try:
        dag = current_app.dag_bag.get_dag(dag_id)
    except SerializedDagNotFound:
        # If DAG is not found in serialized_dag table
        raise NotFound(error_message)
    if not dag:
        raise NotFound(error_message)

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)
    if not task:
        raise NotFound(f"Task ID {task_id} not found")

    # NOTE(review): ``session`` is not forwarded to set_state here; confirm
    # whether set_state is meant to run in its own session.
    tis = set_state(
        tasks=[task],
        execution_date=data["execution_date"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        state=data["new_state"],
        commit=not data["dry_run"],  # dry_run disables committing
    )

    # Skip the DagRun query entirely when set_state returned nothing. It would
    # only be a wasted round-trip with an empty IN (...) predicate.
    execution_date_to_run_id_map = {}
    if tis:
        execution_dates = {ti.execution_date for ti in tis}
        execution_date_to_run_id_map = dict(
            session.query(DR.execution_date, DR.run_id).filter(
                DR.dag_id == dag_id, DR.execution_date.in_(execution_dates)
            )
        )
    tis_with_run_id = [(ti, execution_date_to_run_id_map.get(ti.execution_date)) for ti in tis]
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis_with_run_id)
    )
def test_validation_error(self, override_data):
    """Overriding the valid payload with ``override_data`` makes loading fail.

    ``override_data`` presumably comes from test parametrization. Verify
    against the decorator, which is not visible here.
    """
    payload = self.current_input  # same dict object; the update below mutates current_input
    payload.update(override_data)
    with pytest.raises(ValidationError):
        set_task_instance_state_form.load(payload)