Example #1
def patch_pool(
    *,
    pool_name: str,
    update_mask: UpdateMask = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Update a pool"""
    request_dict = get_json_request_dict()
    # Only slots can be modified in 'default_pool'; renaming it is rejected
    # unless the update mask limits the patch to "slots".
    if (
        pool_name == Pool.DEFAULT_POOL_NAME
        and request_dict.get("name", Pool.DEFAULT_POOL_NAME) != Pool.DEFAULT_POOL_NAME
        and not (update_mask and len(update_mask) == 1 and update_mask[0].strip() == "slots")
    ):
        raise BadRequest(detail="Default Pool's name can't be modified")

    pool = session.query(Pool).filter(Pool.pool == pool_name).first()
    if not pool:
        raise NotFound(detail=f"Pool with name:'{pool_name}' not found")

    try:
        patch_body = pool_schema.load(request_dict)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    if update_mask:
        update_mask = [i.strip() for i in update_mask]
        try:
            # Map each masked field to its model attribute name where the
            # schema declares one, otherwise keep the field name as-is.
            update_mask = [
                pool_schema.declared_fields[field].attribute or field
                for field in update_mask
            ]
        except KeyError as err:
            raise BadRequest(detail=f"Invalid field: {err.args[0]} in update mask")
        patch_body = {field: patch_body[field] for field in update_mask}

    else:
        required_fields = {"name", "slots"}
        fields_diff = required_fields - set(request_dict.keys())
        if fields_diff:
            raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}")

    for key, value in patch_body.items():
        setattr(pool, key, value)
    session.commit()
    return pool_schema.dump(pool)
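A hedged usage sketch for the handler above via Airflow's stable REST API: the base URL and credentials are placeholders for a real deployment, and update_mask=slots restricts the patch to the slots field.

import requests

# Hypothetical deployment details; not part of the handler above.
BASE_URL = "http://localhost:8080/api/v1"
AUTH = ("admin", "admin")

resp = requests.patch(
    f"{BASE_URL}/pools/default_pool",
    params={"update_mask": "slots"},  # only "slots" may change on default_pool
    json={"name": "default_pool", "slots": 128},
    auth=AUTH,
)
resp.raise_for_status()
print(resp.json())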
Example #2
def update_dag_run_state(*,
                         dag_id: str,
                         dag_run_id: str,
                         session: Session = NEW_SESSION) -> APIResponse:
    """Set a state of a dag run."""
    dag_run: Optional[DagRun] = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none())
    if dag_run is None:
        error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
        raise NotFound(error_message)
    try:
        post_body = set_dagrun_state_form_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    state = post_body['state']
    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if state == DagRunState.SUCCESS:
        set_dag_run_state_to_success(dag=dag,
                                     run_id=dag_run.run_id,
                                     commit=True)
    elif state == DagRunState.QUEUED:
        set_dag_run_state_to_queued(dag=dag,
                                    run_id=dag_run.run_id,
                                    commit=True)
    else:
        set_dag_run_state_to_failed(dag=dag,
                                    run_id=dag_run.run_id,
                                    commit=True)
    dag_run = session.query(DagRun).get(dag_run.id)
    return dagrun_schema.dump(dag_run)
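For context, a minimal client call that exercises this endpoint; the URL and credentials are assumptions about the deployment, not part of the handler.

import requests

resp = requests.patch(
    "http://localhost:8080/api/v1/dags/example_dag/dagRuns/manual__2023-01-01T00:00:00+00:00",
    json={"state": "success"},  # "queued" and "failed" take the other branches above
    auth=("admin", "admin"),    # placeholder credentials
)
print(resp.json()["state"])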
Example #3
def post_clear_task_instances(*,
                              dag_id: str,
                              session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances."""
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    task_instances = dag.clear(dry_run=True,
                               dag_bag=get_airflow_app().dag_bag,
                               **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,  # False leaves run state untouched
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all()))
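A sketch of the request body this handler expects: dry_run and reset_dag_runs are popped explicitly, and the remaining keys pass through to dag.clear. The URL and credentials below are placeholders.

import requests

resp = requests.post(
    "http://localhost:8080/api/v1/dags/example_dag/clearTaskInstances",
    json={
        "dry_run": True,         # only list what would be cleared
        "reset_dag_runs": True,  # requeue affected runs on a real clear
        "only_failed": True,
    },
    auth=("admin", "admin"),  # placeholder credentials
)
print(resp.json())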
Example #4
def post_variables() -> Response:
    """Create a variable"""
    try:
        data = variable_schema.load(get_json_request_dict())

    except ValidationError as err:
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))
    Variable.set(data["key"], data["val"])
    return variable_schema.dump(data)
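Note the handler reads data["val"] although clients send "value". A minimal marshmallow sketch of how such a mapping can arise (illustrative only, not Airflow's actual variable_schema):

from marshmallow import Schema, fields

class VariableSketchSchema(Schema):
    key = fields.Str(required=True)
    # Exposed to clients as "value", stored under the attribute name "val"
    value = fields.Str(attribute="val", required=True)

print(VariableSketchSchema().load({"key": "env", "value": "prod"}))
# {'key': 'env', 'val': 'prod'}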
Example #5
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances."""
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    dag_run_id = data.pop('dag_run_id', None)
    future = data.pop('include_future', False)
    past = data.pop('include_past', False)
    downstream = data.pop('include_downstream', False)
    upstream = data.pop('include_upstream', False)
    if dag_run_id is not None:
        dag_run: Optional[DR] = (
            session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
        )
        if dag_run is None:
            error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
            raise NotFound(error_message)
        data['start_date'] = dag_run.logical_date
        data['end_date'] = dag_run.logical_date
    if past:
        data['start_date'] = None
    if future:
        data['end_date'] = None
    task_ids = data.pop('task_ids', None)
    if task_ids is not None:
        task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids]
        dag = dag.partial_subset(
            task_ids_or_regex=task_id,
            include_downstream=downstream,
            include_upstream=upstream,
        )

        if len(dag.task_dict) > 1:
            # If upstream/downstream were requested, also include the tasks
            # the subset pulled in beyond the explicitly requested IDs.
            task_ids.extend(tid for tid in dag.task_dict if tid not in task_id)
    task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,  # False leaves run state untouched
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all())
    )
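The date-window juggling above is easy to miss; a small hypothetical helper restating it, for explanation only:

from datetime import datetime

def clearing_window(logical_date, include_past=False, include_future=False):
    # A specific dag_run_id pins both ends of the window to its logical_date;
    # include_past/include_future then remove the respective bound.
    start_date = None if include_past else logical_date
    end_date = None if include_future else logical_date
    return start_date, end_date

print(clearing_window(datetime(2023, 1, 1), include_future=True))
# (datetime.datetime(2023, 1, 1, 0, 0), None)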
Example #6
def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
    """Get list of task instances."""
    body = get_json_request_dict()
    try:
        data = task_instance_batch_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    states = _convert_state(data['state'])
    base_query = session.query(TI).join(TI.dag_run)

    base_query = _apply_array_filter(base_query,
                                     key=TI.dag_id,
                                     values=data["dag_ids"])
    base_query = _apply_range_filter(
        base_query,
        key=DR.execution_date,
        value_range=(data["execution_date_gte"], data["execution_date_lte"]),
    )
    base_query = _apply_range_filter(
        base_query,
        key=TI.start_date,
        value_range=(data["start_date_gte"], data["start_date_lte"]),
    )
    base_query = _apply_range_filter(base_query,
                                     key=TI.end_date,
                                     value_range=(data["end_date_gte"],
                                                  data["end_date_lte"]))
    base_query = _apply_range_filter(base_query,
                                     key=TI.duration,
                                     value_range=(data["duration_gte"],
                                                  data["duration_lte"]))
    base_query = _apply_array_filter(base_query, key=TI.state, values=states)
    base_query = _apply_array_filter(base_query,
                                     key=TI.pool,
                                     values=data["pool"])
    base_query = _apply_array_filter(base_query,
                                     key=TI.queue,
                                     values=data["queue"])

    # Count elements before joining extra columns
    total_entries = base_query.with_entities(func.count('*')).scalar()
    # Add join
    base_query = base_query.join(
        SlaMiss,
        and_(
            SlaMiss.dag_id == TI.dag_id,
            SlaMiss.task_id == TI.task_id,
            SlaMiss.execution_date == DR.execution_date,
        ),
        isouter=True,
    ).add_entity(SlaMiss)
    ti_query = base_query.options(joinedload(TI.rendered_task_instance_fields))
    task_instances = ti_query.all()

    return task_instance_collection_schema.dump(
        TaskInstanceCollection(task_instances=task_instances,
                               total_entries=total_entries))
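The _apply_array_filter/_apply_range_filter helpers are not shown here; a rough sketch of what such None-tolerant helpers typically look like (the real implementations may differ):

def apply_array_filter(query, key, values):
    # Skip the constraint entirely when the client did not supply it
    if values is None:
        return query
    return query.filter(key.in_(values))

def apply_range_filter(query, key, value_range):
    # Each bound is applied independently, so half-open ranges work too
    gte, lte = value_range
    if gte is not None:
        query = query.filter(key >= gte)
    if lte is not None:
        query = query.filter(key <= lte)
    return query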
Example #7
def post_set_task_instances_state(*,
                                  dag_id: str,
                                  session: Session = NEW_SESSION
                                  ) -> APIResponse:
    """Set a state of task instances."""
    body = get_json_request_dict()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    error_message = f"Dag ID {dag_id} not found"
    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(error_message)

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)

    if not task:
        error_message = f"Task ID {task_id} not found"
        raise NotFound(error_message)

    execution_date = data.get('execution_date')
    run_id = data.get('dag_run_id')
    if execution_date and (
        session.query(TI)
        .filter(TI.task_id == task_id, TI.dag_id == dag_id, TI.execution_date == execution_date)
        .one_or_none()
        is None
    ):
        raise NotFound(
            detail=f"Task instance not found for task {task_id!r} on execution_date {execution_date}"
        )

    if run_id and not session.query(TI).get(
        {'task_id': task_id, 'dag_id': dag_id, 'run_id': run_id, 'map_index': -1}
    ):
        error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
        raise NotFound(detail=error_message)

    tis = dag.set_task_instance_state(
        task_id=task_id,
        run_id=run_id,
        execution_date=execution_date,
        state=data["new_state"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        commit=not data["dry_run"],
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis))
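A hedged client-side sketch: one of execution_date or dag_run_id identifies the anchor task instance, and dry_run=True previews the affected set. URL and credentials are placeholders.

import requests

resp = requests.post(
    "http://localhost:8080/api/v1/dags/example_dag/updateTaskInstancesState",
    json={
        "task_id": "extract",
        "dag_run_id": "manual__2023-01-01T00:00:00+00:00",
        "new_state": "success",
        "include_upstream": False,
        "include_downstream": True,
        "include_future": False,
        "include_past": False,
        "dry_run": True,  # preview only; the handler passes commit=not dry_run
    },
    auth=("admin", "admin"),  # placeholder credentials
)
print(resp.json())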
Example #8
def post_dag_run(*,
                 dag_id: str,
                 session: Session = NEW_SESSION) -> APIResponse:
    """Trigger a DAG."""
    dm = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if not dm:
        raise NotFound(title="DAG not found",
                       detail=f"DAG with dag_id: '{dag_id}' not found")
    if dm.has_import_errors:
        raise BadRequest(
            title="DAG cannot be triggered",
            detail=f"DAG with dag_id: '{dag_id}' has import errors",
        )
    try:
        post_body = dagrun_schema.load(get_json_request_dict(),
                                       session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    logical_date = pendulum.instance(post_body["execution_date"])
    run_id = post_body["run_id"]
    dagrun_instance = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id,
        or_(DagRun.run_id == run_id, DagRun.execution_date == logical_date),
    ).first())
    if not dagrun_instance:
        try:
            dag = get_airflow_app().dag_bag.get_dag(dag_id)
            dag_run = dag.create_dagrun(
                run_type=DagRunType.MANUAL,
                run_id=run_id,
                execution_date=logical_date,
                data_interval=dag.timetable.infer_manual_data_interval(
                    run_after=logical_date),
                state=DagRunState.QUEUED,
                conf=post_body.get("conf"),
                external_trigger=True,
                dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
            )
            return dagrun_schema.dump(dag_run)
        except ValueError as ve:
            raise BadRequest(detail=str(ve))

    if dagrun_instance.execution_date == logical_date:
        raise AlreadyExists(
            detail=(
                f"DAGRun with DAG ID: '{dag_id}' and "
                f"DAGRun logical date: '{logical_date.isoformat(sep=' ')}' already exists"
            )
        )

    raise AlreadyExists(
        detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists"
    )
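Triggering a run through this handler might look as follows; the run ID and conf are arbitrary, and the deployment details are assumptions.

import requests

resp = requests.post(
    "http://localhost:8080/api/v1/dags/example_dag/dagRuns",
    json={"dag_run_id": "manual__smoke_test", "conf": {"target": "staging"}},
    auth=("admin", "admin"),  # placeholder credentials
)
if resp.status_code == 409:
    print("run_id or logical date already taken")  # the AlreadyExists branches
else:
    print(resp.json())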
Example #9
def post_pool(*, session: Session = NEW_SESSION) -> APIResponse:
    """Create a pool"""
    required_fields = {"name", "slots"
                       }  # Pool would require both fields in the post request
    fields_diff = required_fields - set(get_json_request_dict().keys())
    if fields_diff:
        raise BadRequest(
            detail=f"Missing required property(ies): {sorted(fields_diff)}")

    try:
        post_body = pool_schema.load(get_json_request_dict(), session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    pool = Pool(**post_body)
    try:
        session.add(pool)
        session.commit()
        return pool_schema.dump(pool)
    except IntegrityError:
        raise AlreadyExists(detail=f"Pool: {post_body['pool']} already exists")
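A minimal create call satisfying the required_fields check above; the path and credentials are deployment assumptions.

import requests

resp = requests.post(
    "http://localhost:8080/api/v1/pools",
    json={"name": "etl_pool", "slots": 16},  # both fields are mandatory here
    auth=("admin", "admin"),  # placeholder credentials
)
print(resp.status_code)  # 409 when the pool already exists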
Example #10
def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response:
    """Update a variable by key"""
    try:
        data = variable_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))

    if data["key"] != variable_key:
        raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter")

    if update_mask:
        if "key" in update_mask:
            raise BadRequest("key is a ready only field")
        if "value" not in update_mask:
            raise BadRequest("No field to update")

    Variable.set(data["key"], data["val"])
    return variable_schema.dump(data)
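Client-side, the key in the body must repeat the key from the URL, and the update mask may only name "value"; a hedged sketch:

import requests

resp = requests.patch(
    "http://localhost:8080/api/v1/variables/env",
    params={"update_mask": "value"},        # "key" is read-only
    json={"key": "env", "value": "prod"},   # body key must match the URI key
    auth=("admin", "admin"),                # placeholder credentials
)
print(resp.json())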
Example #11
def clear_dag_run(*,
                  dag_id: str,
                  dag_run_id: str,
                  session: Session = NEW_SESSION) -> APIResponse:
    """Clear a dag run."""
    dag_run: Optional[DagRun] = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none())
    if dag_run is None:
        error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
        raise NotFound(error_message)
    try:
        post_body = clear_dagrun_form_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    dry_run = post_body.get('dry_run', False)
    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    start_date = dag_run.logical_date
    end_date = dag_run.logical_date

    if dry_run:
        task_instances = dag.clear(
            start_date=start_date,
            end_date=end_date,
            task_ids=None,
            include_subdags=True,
            include_parentdag=True,
            only_failed=False,
            dry_run=True,
        )
        return task_instance_reference_collection_schema.dump(
            TaskInstanceReferenceCollection(task_instances=task_instances))
    else:
        dag.clear(
            start_date=start_date,
            end_date=end_date,
            task_ids=None,
            include_subdags=True,
            include_parentdag=True,
            only_failed=False,
        )
        dag_run.refresh_from_db()
        return dagrun_schema.dump(dag_run)
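Usage sketch: a dry run returns the task instances that would be cleared, while a real clear resets them and returns the refreshed run. Deployment details are placeholders.

import requests

CLEAR_URL = ("http://localhost:8080/api/v1/dags/example_dag"
             "/dagRuns/manual__2023-01-01T00:00:00+00:00/clear")
AUTH = ("admin", "admin")  # placeholder credentials

preview = requests.post(CLEAR_URL, json={"dry_run": True}, auth=AUTH)
print(preview.json())  # the task instances that would be cleared
requests.post(CLEAR_URL, json={"dry_run": False}, auth=AUTH)  # actually clear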
Example #12
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
    """Get list of DAG Runs"""
    body = get_json_request_dict()
    try:
        data = dagruns_batch_form_schema.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    appbuilder = get_airflow_app().appbuilder
    readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
    query = session.query(DagRun)
    if data.get("dag_ids"):
        dag_ids = set(data["dag_ids"]) & set(readable_dag_ids)
        query = query.filter(DagRun.dag_id.in_(dag_ids))
    else:
        query = query.filter(DagRun.dag_id.in_(readable_dag_ids))

    states = data.get("states")
    if states:
        query = query.filter(DagRun.state.in_(states))

    dag_runs, total_entries = _fetch_dag_runs(
        query,
        end_date_gte=data["end_date_gte"],
        end_date_lte=data["end_date_lte"],
        execution_date_gte=data["execution_date_gte"],
        execution_date_lte=data["execution_date_lte"],
        start_date_gte=data["start_date_gte"],
        start_date_lte=data["start_date_lte"],
        limit=data["page_limit"],
        offset=data["page_offset"],
        order_by=data.get("order_by", "id"),
    )

    return dagrun_collection_schema.dump(
        DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))
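A hypothetical batch query against this handler; "~" in the path stands for all DAGs readable by the current user, and omitted filters are simply not applied. Field names follow the batch form above; the URL, credentials, and order_by value are assumptions.

import requests

resp = requests.post(
    "http://localhost:8080/api/v1/dags/~/dagRuns/list",
    json={
        "dag_ids": ["example_dag"],
        "states": ["success", "failed"],
        "page_limit": 50,
        "page_offset": 0,
        "order_by": "-execution_date",
    },
    auth=("admin", "admin"),  # placeholder credentials
)
print(resp.json()["total_entries"])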