Ejemplo n.º 1
0
def post_set_task_instances_state(dag_id, session):
    """Change the state of a task instance and, optionally, its relatives.

    The JSON request body is validated with ``set_task_instance_state_form``;
    the DAG and the task are then resolved and the actual update is delegated
    to ``dag.set_task_instance_state``.

    :param dag_id: ID of the DAG that owns the task.
    :param session: session forwarded to ``dag.set_task_instance_state``.
    :return: the affected task instances, dumped through
        ``task_instance_reference_collection_schema``.
    :raises BadRequest: the request body fails schema validation.
    :raises NotFound: the DAG or the task does not exist.
    """
    raw_payload = request.get_json()
    try:
        form = set_task_instance_state_form.load(raw_payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag ID {dag_id} not found")

    task_id = form['task_id']
    if not dag.task_dict.get(task_id):
        raise NotFound(f"Task ID {task_id} not found")

    # ``commit`` is the negation of the request's ``dry_run`` flag.
    updated_tis = dag.set_task_instance_state(
        task_id=task_id,
        execution_date=form["execution_date"],
        state=form["new_state"],
        upstream=form["include_upstream"],
        downstream=form["include_downstream"],
        future=form["include_future"],
        past=form["include_past"],
        commit=not form["dry_run"],
        session=session,
    )
    reference_collection = TaskInstanceReferenceCollection(task_instances=updated_tis)
    return task_instance_reference_collection_schema.dump(reference_collection)
Ejemplo n.º 2
0
 def test_success(self):
     """A fully valid payload loads into the expected, typed form dict."""
     # NOTE(review): relies on self.current_input being prepared elsewhere
     # (e.g. setUp, not visible here) with values matching the dict below.
     utc = dt.timezone(dt.timedelta(0), '+0000')
     expected_result = {
         'task_id': 'print_the_context',
         'new_state': 'failed',
         'execution_date': dt.datetime(2020, 1, 1, 0, 0, tzinfo=utc),
         'dry_run': True,
         'include_upstream': True,
         'include_downstream': True,
         'include_future': True,
         'include_past': True,
     }
     result = set_task_instance_state_form.load(self.current_input)
     self.assertEqual(expected_result, result)
Ejemplo n.º 3
0
def post_set_task_instances_state(
    *, dag_id: str, session: Session = NEW_SESSION
) -> APIResponse:
    """Set a state of task instances.

    Validates the JSON request body, checks that the DAG, the task and the
    targeted task instance exist (by ``execution_date`` and/or ``dag_run_id``),
    then delegates to ``dag.set_task_instance_state``. ``commit`` is the
    negation of the request's ``dry_run`` flag.

    :param dag_id: ID of the DAG that owns the task.
    :param session: session used for the existence checks and the update.
    :return: serialized ``TaskInstanceReferenceCollection`` of the affected
        task instances.
    :raises BadRequest: the request body fails schema validation.
    :raises NotFound: the DAG, the task, or the targeted task instance does
        not exist.
    """
    body = request.get_json()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages)) from err

    error_message = f"Dag ID {dag_id} not found"
    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(error_message)

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)
    if not task:
        error_message = f"Task ID {task_id} not found"
        raise NotFound(error_message)

    execution_date = data.get('execution_date')
    run_id = data.get('dag_run_id')

    # Pure existence check. Task instances are also keyed by map_index (see the
    # run_id lookup below), so several rows may match this filter; ``.first()``
    # returns None only when nothing matches, whereas ``.one_or_none()`` would
    # raise MultipleResultsFound for more than one row.
    if execution_date and session.query(TI).filter(
        TI.task_id == task_id,
        TI.dag_id == dag_id,
        TI.execution_date == execution_date,
    ).first() is None:
        raise NotFound(
            detail=f"Task instance not found for task {task_id!r} on execution_date {execution_date}"
        )

    # Primary-key lookup; map_index -1 addresses the unmapped task instance.
    if run_id and not session.query(TI).get({
            'task_id': task_id,
            'dag_id': dag_id,
            'run_id': run_id,
            'map_index': -1
    }):
        error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
        raise NotFound(detail=error_message)

    tis = dag.set_task_instance_state(
        task_id=task_id,
        dag_run_id=run_id,
        execution_date=execution_date,
        state=data["new_state"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        commit=not data["dry_run"],
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis))
def post_set_task_instances_state(
    *, dag_id: str, session: Session = NEW_SESSION
) -> APIResponse:
    """Change the state of a task instance identified by its execution date.

    Steps: validate the request body, resolve the DAG and the task, make sure
    a task instance exists for ``execution_date``, then delegate the update to
    ``dag.set_task_instance_state``.

    :param dag_id: ID of the DAG that owns the task.
    :param session: session used for the existence check and the update.
    :return: serialized collection of the affected task instances.
    :raises BadRequest: the request body fails schema validation.
    :raises NotFound: the DAG, the task, or the task instance does not exist.
    """
    raw_payload = request.get_json()
    try:
        form = set_task_instance_state_form.load(raw_payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag ID {dag_id} not found")

    task_id = form['task_id']
    if not dag.task_dict.get(task_id):
        raise NotFound(f"Task ID {task_id} not found")

    execution_date = form['execution_date']
    ti_query = session.query(TI).filter_by(
        execution_date=execution_date, task_id=task_id, dag_id=dag_id
    )
    # ``.one()`` raises NoResultFound when no row matches; report it as NotFound.
    try:
        ti_query.one()
    except NoResultFound:
        raise NotFound(
            f"Task instance not found for task {task_id} on execution_date {execution_date}"
        )

    # ``commit`` is the negation of the request's ``dry_run`` flag.
    updated_tis = dag.set_task_instance_state(
        task_id=task_id,
        execution_date=execution_date,
        state=form["new_state"],
        upstream=form["include_upstream"],
        downstream=form["include_downstream"],
        future=form["include_future"],
        past=form["include_past"],
        commit=not form["dry_run"],
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=updated_tis)
    )
Ejemplo n.º 5
0
def post_set_task_instances_state(dag_id, session):
    """Set a state of task instances.

    Validates the JSON request body, resolves the DAG and the task, applies
    the change through ``set_state`` and returns references to the affected
    task instances, each paired with the ID of the DAG run it belongs to.
    ``commit`` is the negation of the request's ``dry_run`` flag.

    :param dag_id: ID of the DAG that owns the task.
    :param session: session used to look up DAG-run IDs for the result.
    :return: serialized ``TaskInstanceReferenceCollection`` whose entries are
        ``(task_instance, run_id)`` pairs; ``run_id`` is None when no DAG run
        of this DAG matches the task instance's execution date.
    :raises BadRequest: the request body fails schema validation.
    :raises NotFound: the DAG (including one absent from the serialized_dag
        table) or the task does not exist.
    """
    body = request.get_json()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages)) from err

    error_message = f"Dag ID {dag_id} not found"
    # Keep the try body to the single call that can raise SerializedDagNotFound.
    try:
        dag = current_app.dag_bag.get_dag(dag_id)
    except SerializedDagNotFound:
        # DAG is not in the serialized_dag table: report it like any unknown DAG.
        dag = None
    if not dag:
        raise NotFound(error_message)

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)
    if not task:
        raise NotFound(f"Task ID {task_id} not found")

    tis = set_state(
        tasks=[task],
        execution_date=data["execution_date"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        state=data["new_state"],
        commit=not data["dry_run"],
    )

    # Resolve execution_date -> run_id for all affected TIs with one query;
    # skip the database round trip entirely when nothing was affected.
    execution_dates = {ti.execution_date for ti in tis}
    execution_date_to_run_id_map = {}
    if execution_dates:
        execution_date_to_run_id_map = dict(
            session.query(DR.execution_date, DR.run_id).filter(
                DR.dag_id == dag_id, DR.execution_date.in_(execution_dates)
            )
        )
    tis_with_run_id = [(ti, execution_date_to_run_id_map.get(ti.execution_date)) for ti in tis]
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis_with_run_id)
    )
Ejemplo n.º 6
0
    def test_validation_error(self, override_data):
        """Overriding part of the valid payload must make loading raise ValidationError."""
        # self.current_input is updated in place with the overriding values.
        payload = self.current_input
        payload.update(override_data)

        with pytest.raises(ValidationError):
            set_task_instance_state_form.load(payload)