Example #1
0
def post_clear_task_instances(*,
                              dag_id: str,
                              session: Session = NEW_SESSION) -> APIResponse:
    """Clear the task instances of a DAG as described by the JSON request body.

    The matching task instances are looked up with ``dag.clear(dry_run=True)``.
    Unless the request itself is a dry run, they are then cleared with
    ``clear_task_instances``. References to the matched instances are returned.

    :param dag_id: ID of the DAG whose task instances are cleared.
    :param session: SQLAlchemy session handed to ``clear_task_instances``.
    :return: Serialized ``TaskInstanceReferenceCollection``.
    :raises BadRequest: if the body fails ``clear_task_instance_form`` validation.
    :raises NotFound: if the DAG is not in the DagBag.
    """
    payload = request.get_json()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag_bag = current_app.dag_bag
    dag = dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag id {dag_id} not found")

    reset_dag_runs = form.pop('reset_dag_runs')
    is_dry_run = form.pop('dry_run')

    # dag.clear() is only ever used as a dry run here: a real run would try to
    # confirm on the terminal. The actual clearing is done just below.
    matched = dag.clear(dry_run=True, dag_bag=dag_bag, **form)
    if not is_dry_run:
        run_state = DagRunState.QUEUED if reset_dag_runs else False
        clear_task_instances(matched.all(), session, dag=dag, dag_run_state=run_state)

    collection = TaskInstanceReferenceCollection(task_instances=matched.all())
    return task_instance_reference_collection_schema.dump(collection)
Example #2
0
def post_clear_task_instances(dag_id: str, session=None):
    """
    Clear the task instances of a DAG as described by the JSON request body.

    Task instances are selected with ``dag.clear(get_tis=True, ...)``. Unless the
    request's ``dry_run`` flag is set, they are cleared without activating their
    DagRuns. If ``reset_dag_runs`` was requested, the DagRuns in the
    ``start_date``..``end_date`` range are then set to RUNNING. The response
    lists each matched task instance together with its DagRun's ``run_id``.

    :param dag_id: ID of the DAG whose task instances are cleared.
    :param session: SQLAlchemy session used for clearing and DagRun updates.
    :return: Serialized ``TaskInstanceReferenceCollection``.
    :raises BadRequest: if the body fails ``clear_task_instance_form`` validation.
    :raises NotFound: if the DAG is not in the DagBag.
    """
    payload = request.get_json()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound("Dag id {} not found".format(dag_id))

    reset_dag_runs = form.pop('reset_dag_runs')
    tis_query = dag.clear(get_tis=True, **form)
    if not form["dry_run"]:
        # DagRuns are deliberately left alone here; their state is set below.
        clear_task_instances(tis_query, session, dag=dag, activate_dag_runs=False)
        if reset_dag_runs:
            dag.set_dag_runs_state(
                session=session,
                start_date=form["start_date"],
                end_date=form["end_date"],
                state=State.RUNNING,
            )

    # Attach each task instance's run_id by joining against its DagRun.
    join_condition = and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)
    with_run_ids = tis_query.join(DR, join_condition).add_column(DR.run_id)
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=with_run_ids.all())
    )
Example #3
0
def post_clear_task_instances(dag_id: str, session=None):
    """Clear the task instances of a DAG as described by the JSON request body.

    Matching task instances are found via ``dag.clear(dry_run=True)``. Unless the
    request is itself a dry run, they are cleared with ``clear_task_instances``.
    DagRuns are moved to RUNNING only when ``reset_dag_runs`` is set. The response
    pairs each matched task instance with its DagRun's ``run_id``.

    :param dag_id: ID of the DAG whose task instances are cleared.
    :param session: SQLAlchemy session handed to ``clear_task_instances``.
    :return: Serialized ``TaskInstanceReferenceCollection``.
    :raises BadRequest: if the body fails ``clear_task_instance_form`` validation.
    :raises NotFound: if the DAG is not in the DagBag.
    """
    payload = request.get_json()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag_bag = current_app.dag_bag
    dag = dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag id {dag_id} not found")

    reset_dag_runs = form.pop('reset_dag_runs')
    is_dry_run = form.pop('dry_run')

    # dag.clear() is only ever used as a dry run here: a real run would try to
    # confirm on the terminal. The actual clearing is done just below.
    tis_query = dag.clear(dry_run=True, dag_bag=dag_bag, **form)
    if not is_dry_run:
        run_state = State.RUNNING if reset_dag_runs else False
        clear_task_instances(tis_query, session, dag=dag, dag_run_state=run_state)

    # Attach each task instance's run_id by joining against its DagRun.
    join_condition = and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)
    with_run_ids = tis_query.join(DR, join_condition).add_column(DR.run_id)
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=with_run_ids.all())
    )
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances.

    The request body is validated with ``clear_task_instance_form``. The matching task
    instances are always looked up with ``dag.clear(dry_run=True, ...)``. Unless the
    request's ``dry_run`` flag is set, they are then cleared with ``clear_task_instances``.
    DagRuns are set to QUEUED when ``reset_dag_runs`` is requested.

    :param dag_id: ID of the DAG whose task instances are cleared.
    :param session: SQLAlchemy session used for the DagRun lookup and for clearing.
    :return: Serialized ``TaskInstanceReferenceCollection`` of the matched task instances.
    :raises BadRequest: if the request body fails schema validation.
    :raises NotFound: if the DAG, or the requested ``dag_run_id`` within it, does not exist.
    """
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    dag_run_id = data.pop('dag_run_id', None)
    future = data.pop('include_future', False)
    past = data.pop('include_past', False)
    downstream = data.pop('include_downstream', False)
    upstream = data.pop('include_upstream', False)
    if dag_run_id is not None:
        dag_run: Optional[DR] = (
            session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
        )
        if dag_run is None:
            error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
            raise NotFound(error_message)
        # Narrow the date window to exactly this run's logical date.
        data['start_date'] = dag_run.logical_date
        data['end_date'] = dag_run.logical_date
    # include_past / include_future remove the corresponding bound of the window,
    # even when it was just set from a specific DagRun above.
    if past:
        data['start_date'] = None
    if future:
        data['end_date'] = None
    task_ids = data.pop('task_ids', None)
    if task_ids is not None:
        # Entries are either plain task ids or tuples whose first element is the task id
        # (presumably (task_id, map_index) pairs -- confirm against clear_task_instance_form).
        requested_task_ids = [task[0] if isinstance(task, tuple) else task for task in task_ids]
        dag = dag.partial_subset(
            task_ids_or_regex=requested_task_ids,
            include_downstream=downstream,
            include_upstream=upstream,
        )

        if len(dag.task_dict) > 1:
            # If we had upstream/downstream etc then also include those! Only add tasks that
            # were not explicitly requested. Re-adding a requested task as a bare id would
            # duplicate it and, for tuple entries, widen the selection to the whole task.
            already_requested = set(requested_task_ids)
            task_ids.extend(tid for tid in dag.task_dict if tid not in already_requested)
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all())
    )
 def test_validation_error(self, payload):
     """Loading ``payload`` through ``clear_task_instance_form`` must raise ``ValidationError``."""
     pytest.raises(ValidationError, clear_task_instance_form.load, payload)
    def test_validation_error(self, override_data):
        """``self.current_input`` patched with ``override_data`` must fail form validation."""
        self.current_input.update(override_data)
        self.assertRaises(ValidationError, clear_task_instance_form.load, self.current_input)