Beispiel #1
0
def _check_action_and_resource(sm: AirflowSecurityManager,
                               perms: List[Tuple[str, str]]) -> None:
    """
    Validate that every action and resource named in ``perms`` exists.

    Pairs are checked in order. For each pair the action is looked up first
    and then the resource. The first unknown name aborts with HTTP 400, so
    this helper is only meant for REST API handlers.

    :param sm: security manager used to look up actions and resources
    :param perms: (action name, resource name) pairs to validate
    :raises BadRequest: if an action or a resource is not found
    """
    for action_name, resource_name in perms:
        action_found = sm.get_action(action_name)
        if not action_found:
            raise BadRequest(
                detail=f"The specified action: {action_name!r} was not found")
        resource_found = sm.get_resource(resource_name)
        if not resource_found:
            raise BadRequest(
                detail=f"The specified resource: {resource_name!r} was not found")
def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content=False, token=None):
    """Get logs for specific task instance.

    The optional continuation ``token`` is a value previously returned by this
    endpoint, signed with the app's ``SECRET_KEY``. It carries the log
    reader's metadata between successive calls.

    :param session: SQLAlchemy session used to look up the DAG run
    :param dag_id: DAG ID
    :param dag_run_id: DAG run ID
    :param task_id: task ID
    :param task_try_number: try number passed to the log reader
    :param full_content: if truthy, request the whole log (``download_logs``)
    :param token: optional signed continuation token
    :return: JSON log chunk plus a new continuation token when the client
        accepts JSON (or expresses no preference); otherwise a streamed
        ``text/plain`` response
    :raises BadRequest: bad token signature, a log handler that cannot read,
        or no task instance in the DB
    :raises NotFound: if the DAG run does not exist
    """
    key = current_app.config["SECRET_KEY"]
    if not token:
        metadata = {}
    else:
        try:
            metadata = URLSafeSerializer(key).loads(token)
        except BadSignature:
            raise BadRequest("Bad Signature. Please use only the tokens provided by the API.") from None

    # A token issued for a full download keeps requesting the full log.
    if metadata.get('download_logs'):
        full_content = True
    metadata['download_logs'] = bool(full_content)

    task_log_reader = TaskLogReader()
    if not task_log_reader.supports_read:
        raise BadRequest("Task log handler does not support read logs.")

    query = session.query(DagRun).filter(DagRun.dag_id == dag_id)
    dag_run = query.filter(DagRun.run_id == dag_run_id).first()
    if not dag_run:
        raise NotFound("DAG Run not found")

    ti = dag_run.get_task_instance(task_id, session)
    if ti is None:
        raise BadRequest(detail="Task instance did not exist in the DB")

    dag = current_app.dag_bag.get_dag(dag_id)
    if dag:
        ti.task = dag.get_task(ti.task_id)

    # best_match returns one of the two offered types, or None when the
    # client accepts neither; None is treated as JSON (the default).
    return_type = request.accept_mimetypes.best_match(['text/plain', 'application/json'])

    if return_type == 'application/json' or return_type is None:
        logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata)
        logs = logs[0] if task_try_number is not None else logs
        token = URLSafeSerializer(key).dumps(metadata)
        return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))

    # text/plain: stream the log.
    logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)
    return Response(logs, headers={"Content-Type": return_type})
Beispiel #3
0
def patch_pool(pool_name, session, update_mask=None):
    """
    Update a pool.

    Without ``update_mask`` the body must contain both ``name`` and ``slots``,
    and every loaded field is applied. With ``update_mask`` only the listed
    fields are applied. Each listed field must be declared on ``pool_schema``
    and present in the body. The default pool may only have ``slots``
    changed.

    :param pool_name: name of the pool to update (from the URL)
    :param session: SQLAlchemy session
    :param update_mask: optional list of schema field names to update
    :raises BadRequest: invalid body, invalid update mask, or an attempt to
        rename the default pool
    :raises NotFound: if no pool named ``pool_name`` exists
    """
    # Only slots can be modified in 'default_pool'. A body without "name"
    # cannot rename it, so only an explicitly different name is checked.
    if pool_name == Pool.DEFAULT_POOL_NAME:
        new_name = request.json.get("name", Pool.DEFAULT_POOL_NAME)
        if new_name != Pool.DEFAULT_POOL_NAME:
            only_slots = (bool(update_mask) and len(update_mask) == 1
                          and update_mask[0].strip() == "slots")
            if not only_slots:
                raise BadRequest(
                    detail="Default Pool's name can't be modified")

    pool = session.query(Pool).filter(Pool.pool == pool_name).first()
    if not pool:
        raise NotFound(detail=f"Pool with name:'{pool_name}' not found")

    try:
        patch_body = pool_schema.load(request.json)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    if update_mask:
        # Map each API field name to the key it is loaded under: the schema
        # field's ``attribute`` when set, otherwise the field name itself.
        mask_to_attr = {}
        for field in [name.strip() for name in update_mask]:
            declared = pool_schema.declared_fields.get(field)
            if declared is None:
                raise BadRequest(
                    detail=f"Invalid field: {field} in update mask")
            mask_to_attr[field] = declared.attribute or field
        # A masked field absent from the body would otherwise raise KeyError
        # below and surface as a 500.
        missing = [field for field, attr in mask_to_attr.items()
                   if attr not in patch_body]
        if missing:
            raise BadRequest(
                detail="Field(s) in update mask missing from request body: "
                f"{', '.join(missing)}")
        patch_body = {attr: patch_body[attr] for attr in mask_to_attr.values()}
    else:
        for field in ["name", "slots"]:
            if field not in request.json.keys():
                raise BadRequest(detail=f"'{field}' is a required property")

    for key, value in patch_body.items():
        setattr(pool, key, value)
    session.commit()
    return pool_schema.dump(pool)
Beispiel #4
0
def get_dag_runs_batch(session):
    """Return the DAG runs matching the batch form in the request body.

    Only DAGs readable by the current user are queried. When the body names
    ``dag_ids``, the readable subset of those IDs is used instead.

    :param session: SQLAlchemy session
    :raises BadRequest: if the request body fails schema validation
    """
    payload = request.get_json()
    try:
        form = dagruns_batch_form_schema.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    readable = current_app.appbuilder.sm.get_readable_dag_ids(g.user)
    requested = form.get("dag_ids")
    visible_ids = set(requested) & set(readable) if requested else readable
    runs_query = session.query(DagRun).filter(DagRun.dag_id.in_(visible_ids))

    dag_runs, total_entries = _fetch_dag_runs(
        runs_query,
        form["end_date_gte"],
        form["end_date_lte"],
        form["execution_date_gte"],
        form["execution_date_lte"],
        form["start_date_gte"],
        form["start_date_lte"],
        form["page_limit"],
        form["page_offset"],
        order_by=form.get('order_by', "id"),
    )

    collection = DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)
    return dagrun_collection_schema.dump(collection)
Beispiel #5
0
def post_clear_task_instances(dag_id: str, session=None):
    """Clear the task instances of ``dag_id`` selected by the request body.

    With ``dry_run`` set, nothing is cleared and only the matching task
    instances are reported. Otherwise they are passed to
    ``clear_task_instances`` with ``dag_run_state=State.RUNNING`` when
    ``reset_dag_runs`` is set, or ``False`` when it is not.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG is not in the DAG bag
    """
    payload = request.get_json()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag id {dag_id} not found")

    reset_dag_runs = form.pop('reset_dag_runs')
    dry_run = form.pop('dry_run')
    # dag.clear itself is always run as a dry run; otherwise it would try to
    # ask for confirmation on the terminal. Real clearing happens below.
    tis_query = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **form)
    if not dry_run:
        run_state = State.RUNNING if reset_dag_runs else False
        clear_task_instances(tis_query, session, dag=dag, dag_run_state=run_state)

    # Attach each instance's DAG run ID for the reference schema.
    run_join = and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)
    tis_with_run_id = tis_query.join(DR, run_join).add_column(DR.run_id)
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis_with_run_id.all()))
Beispiel #6
0
def get_task_instances_batch(session=None):
    """
    Get the task instances matching the batch filter in the request body.

    Array filters narrow by DAG ID, state, pool and queue. Range filters
    narrow by execution date, start date, end date and duration. Each
    instance comes back with its SLA miss row, outer-joined so it may be
    absent.

    :raises BadRequest: if the request body fails schema validation
    """
    payload = request.get_json()
    try:
        form = task_instance_batch_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    query = _apply_array_filter(session.query(TI),
                                key=TI.dag_id,
                                values=form["dag_ids"])
    range_columns = (
        (TI.execution_date, "execution_date"),
        (TI.start_date, "start_date"),
        (TI.end_date, "end_date"),
        (TI.duration, "duration"),
    )
    for column, prefix in range_columns:
        bounds = (form[f"{prefix}_gte"], form[f"{prefix}_lte"])
        query = _apply_range_filter(query, key=column, value_range=bounds)
    for column, field in ((TI.state, "state"), (TI.pool, "pool"), (TI.queue, "queue")):
        query = _apply_array_filter(query, key=column, values=form[field])

    # Count before the SLA join adds extra columns.
    total_entries = query.with_entities(func.count('*')).scalar()

    sla_condition = and_(
        SlaMiss.dag_id == TI.dag_id,
        SlaMiss.task_id == TI.task_id,
        SlaMiss.execution_date == TI.execution_date,
    )
    rows = query.join(SlaMiss, sla_condition, isouter=True).add_entity(SlaMiss).all()

    return task_instance_collection_schema.dump(
        TaskInstanceCollection(task_instances=rows, total_entries=total_entries))
Beispiel #7
0
def get_users(*,
              limit: int,
              order_by: str = "id",
              offset: Optional[str] = None) -> APIResponse:
    """Get users.

    :param limit: maximum number of users returned
    :param order_by: user attribute to sort by. A leading ``-`` sorts in
        descending order, and ``user_id`` is accepted as an alias of ``id``.
    :param offset: number of users to skip before the page starts
    :raises BadRequest: if the sort attribute is not in the allow-list or is
        not an attribute of ``User``
    """
    appbuilder = get_airflow_app().appbuilder
    session = appbuilder.get_session
    total_entries = session.query(func.count(User.id)).scalar()
    direction = desc if order_by.startswith("-") else asc
    to_replace = {"user_id": "id"}
    # Only a leading "-" is meaningful (it selects descending order).
    order_param = order_by.lstrip("-")
    order_param = to_replace.get(order_param, order_param)
    allowed_filter_attrs = [
        'id',
        "first_name",
        "last_name",
        "user_name",
        "email",
        "is_active",
        "role",
    ]
    # Validate the resolved attribute name, not the raw ``order_by``, which
    # may carry the "-" prefix or the "user_id" alias. Allowed names that are
    # not attributes of User are rejected here too, instead of raising
    # AttributeError in getattr.
    order_column = (getattr(User, order_param, None)
                    if order_param in allowed_filter_attrs else None)
    if order_column is None:
        raise BadRequest(detail=f"Ordering with '{order_by}' is disallowed or "
                         f"the attribute does not exist on the model")

    query = session.query(User)
    users = query.order_by(direction(order_column)).offset(offset).limit(limit).all()

    return user_collection_schema.dump(
        UserCollection(users=users, total_entries=total_entries))
Beispiel #8
0
def get_dag_runs_batch(session):
    """
    Get the DAG runs matching the batch form in the request body.

    When the body's ``dag_ids`` is non-empty, only runs of those DAGs are
    considered.

    :param session: SQLAlchemy session, also forwarded to ``_fetch_dag_runs``
    :raises BadRequest: if the request body fails schema validation
    """
    payload = request.get_json()
    try:
        form = dagruns_batch_form_schema.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    runs_query = session.query(DagRun)
    wanted_ids = form["dag_ids"]
    if wanted_ids:
        runs_query = runs_query.filter(DagRun.dag_id.in_(wanted_ids))

    dag_runs, total_entries = _fetch_dag_runs(
        runs_query,
        session,
        form["end_date_gte"],
        form["end_date_lte"],
        form["execution_date_gte"],
        form["execution_date_lte"],
        form["start_date_gte"],
        form["start_date_lte"],
        form["page_limit"],
        form["page_offset"],
    )

    collection = DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)
    return dagrun_collection_schema.dump(collection)
Beispiel #9
0
def post_clear_task_instances(*,
                              dag_id: str,
                              session: Session = NEW_SESSION) -> APIResponse:
    """Clear the task instances of ``dag_id`` selected by the request body.

    With ``dry_run`` set, nothing is cleared and only the matching task
    instances are reported. Otherwise they are passed to
    ``clear_task_instances`` with ``dag_run_state=DagRunState.QUEUED`` when
    ``reset_dag_runs`` is set, or ``False`` when it is not.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG is not in the DAG bag
    """
    payload = request.get_json()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag id {dag_id} not found")

    reset_dag_runs = form.pop('reset_dag_runs')
    dry_run = form.pop('dry_run')
    # dag.clear itself is always run as a dry run; otherwise it would try to
    # ask for confirmation on the terminal. Real clearing happens below.
    tis_query = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **form)
    if not dry_run:
        run_state = DagRunState.QUEUED if reset_dag_runs else False
        clear_task_instances(tis_query.all(), session, dag=dag, dag_run_state=run_state)

    cleared = tis_query.all()
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=cleared))
def post_set_task_instances_state(dag_id, session):
    """Set the state of a task's instances as described by the request body.

    The include_upstream/downstream/future/past flags and ``dry_run``
    (as ``commit=not dry_run``) are forwarded to
    ``dag.set_task_instance_state``.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG or the task does not exist
    """
    payload = request.get_json()
    try:
        form = set_task_instance_state_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag ID {dag_id} not found")

    task_id = form['task_id']
    if not dag.task_dict.get(task_id):
        raise NotFound(f"Task ID {task_id} not found")

    updated_tis = dag.set_task_instance_state(
        task_id=task_id,
        execution_date=form["execution_date"],
        state=form["new_state"],
        upstream=form["include_upstream"],
        downstream=form["include_downstream"],
        future=form["include_future"],
        past=form["include_past"],
        commit=not form["dry_run"],
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=updated_tis))
Beispiel #11
0
def post_dag_run(dag_id, session):
    """Trigger a DAG by creating a manual ``DagRun`` from the request body.

    :raises NotFound: if no ``DagModel`` exists for ``dag_id``
    :raises BadRequest: if the body fails schema validation
    :raises AlreadyExists: if a run with the same run ID or execution date
        already exists for this DAG
    """
    dag_model = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if not dag_model:
        raise NotFound(title="DAG not found",
                       detail=f"DAG with dag_id: '{dag_id}' not found")
    try:
        post_body = dagrun_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    run_id = post_body["run_id"]
    execution_date = post_body["execution_date"]
    clash = or_(DagRun.run_id == run_id, DagRun.execution_date == execution_date)
    existing = session.query(DagRun).filter(DagRun.dag_id == dag_id, clash).first()

    if not existing:
        new_run = DagRun(dag_id=dag_id, run_type=DagRunType.MANUAL, **post_body)
        session.add(new_run)
        session.commit()
        return dagrun_schema.dump(new_run)

    # Report the execution-date clash first; otherwise the run ID clashed.
    if existing.execution_date == execution_date:
        detail = (f"DAGRun with DAG ID: '{dag_id}' and "
                  f"DAGRun ExecutionDate: '{execution_date}' already exists")
    else:
        detail = f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists"
    raise AlreadyExists(detail=detail)
def update_dag_run_state(*,
                         dag_id: str,
                         dag_run_id: str,
                         session: Session = NEW_SESSION) -> APIResponse:
    """Set the state of a DAG run.

    ``SUCCESS`` and ``QUEUED`` map to their dedicated setters. Any other
    state is applied with ``set_dag_run_state_to_failed``. The run is
    re-read before being serialized.

    :raises NotFound: if the DAG run does not exist in the DAG
    :raises BadRequest: if the request body fails schema validation
    """
    dag_run: Optional[DagRun] = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none())
    if dag_run is None:
        raise NotFound(f'Dag Run id {dag_run_id} not found in dag {dag_id}')
    try:
        post_body = set_dagrun_state_form_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    state = post_body['state']
    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if state == DagRunState.SUCCESS:
        apply_state = set_dag_run_state_to_success
    elif state == DagRunState.QUEUED:
        apply_state = set_dag_run_state_to_queued
    else:
        apply_state = set_dag_run_state_to_failed
    apply_state(dag=dag, run_id=dag_run.run_id, commit=True)

    refreshed_run = session.query(DagRun).get(dag_run.id)
    return dagrun_schema.dump(refreshed_run)
Beispiel #13
0
def post_clear_task_instances(dag_id: str, session=None):
    """
    Clear the task instances of ``dag_id`` selected by the request body.

    Unless ``dry_run`` is set, the instances are cleared with
    ``activate_dag_runs=False``. When ``reset_dag_runs`` is set, DAG runs
    between the body's start and end dates are then set to RUNNING.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG is not in the DAG bag
    """
    payload = request.get_json()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag id {dag_id} not found")

    reset_dag_runs = form.pop('reset_dag_runs')
    # dry_run stays in ``form`` so it is forwarded to dag.clear as well.
    tis_query = dag.clear(get_tis=True, **form)
    if not form["dry_run"]:
        clear_task_instances(
            tis_query,
            session,
            dag=dag,
            activate_dag_runs=False,  # DagRun state is set separately below.
        )
        if reset_dag_runs:
            dag.set_dag_runs_state(
                session=session,
                start_date=form["start_date"],
                end_date=form["end_date"],
                state=State.RUNNING,
            )

    run_join = and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)
    tis_with_run_id = tis_query.join(DR, run_join).add_column(DR.run_id)
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis_with_run_id.all()))
Beispiel #14
0
def delete_pool(pool_name: str, session):
    """Delete the pool called ``pool_name`` and return an empty 204 response.

    :raises BadRequest: if asked to delete ``default_pool``
    :raises NotFound: if no row was deleted
    """
    if pool_name == "default_pool":
        raise BadRequest(detail="Default Pool can't be deleted")

    deleted_rows = session.query(Pool).filter(Pool.pool == pool_name).delete()
    if deleted_rows == 0:
        raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
    return Response(status=204)
Beispiel #15
0
def post_user():
    """Create a new user from the JSON request body.

    The last role listed in the body is passed to ``add_user`` as the
    initial role, and the remaining roles are appended afterwards. With no
    roles, the security manager's ``auth_user_registration_role`` is used.

    :return: the serialized new user
    :raises BadRequest: invalid body, or unknown role names
    :raises AlreadyExists: if the username or email is already taken
    :raises Unknown: if the security manager fails to create the user
    """
    try:
        data = user_schema.load(request.json)
    except ValidationError as e:
        raise BadRequest(detail=str(e.messages))

    security_manager = current_app.appbuilder.sm
    username = data["username"]
    email = data["email"]

    if security_manager.find_user(username=username):
        detail = f"Username `{username}` already exists. Use PATCH to update."
        raise AlreadyExists(detail=detail)
    if security_manager.find_user(email=email):
        detail = f"The email `{email}` is already taken."
        raise AlreadyExists(detail=detail)

    roles_to_add = []
    missing_role_names = []
    for role_data in data.pop("roles", ()):
        role_name = role_data["name"]
        role = security_manager.find_role(role_name)
        if role is None:
            missing_role_names.append(role_name)
        else:
            roles_to_add.append(role)
    if missing_role_names:
        detail = f"Unknown roles: {', '.join(repr(n) for n in missing_role_names)}"
        raise BadRequest(detail=detail)

    if roles_to_add:
        default_role = roles_to_add.pop()
    else:  # No roles provided, use the F.A.B's default registered user role.
        default_role = security_manager.find_role(
            security_manager.auth_user_registration_role)

    user = security_manager.add_user(role=default_role, **data)
    if not user:
        detail = f"Failed to add user `{username}`."
        # Like BadRequest/AlreadyExists above, Unknown is an error to raise.
        # Returning it would hand the exception object to the response layer.
        raise Unknown(detail=detail)

    if roles_to_add:
        user.roles.extend(roles_to_add)
        security_manager.update_user(user)
    return user_schema.dump(user)
Beispiel #16
0
def post_dag_run(dag_id, session):
    """Trigger a DAG.

    Creates a manual, externally triggered DAG run in ``QUEUED`` state, using
    the ``run_id``, logical date and ``conf`` from the request body.

    :return: the serialized new DAG run
    :raises NotFound: if no ``DagModel`` exists for ``dag_id``, or the DAG bag
        returns no DAG for it
    :raises BadRequest: invalid body, or a ``ValueError`` raised while
        creating the run
    :raises AlreadyExists: if a run with the same run ID or logical date
        already exists for this DAG
    """
    if not session.query(DagModel).filter(DagModel.dag_id == dag_id).first():
        raise NotFound(title="DAG not found", detail=f"DAG with dag_id: '{dag_id}' not found")
    try:
        post_body = dagrun_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    logical_date = pendulum.instance(post_body["execution_date"])
    run_id = post_body["run_id"]
    dagrun_instance = (
        session.query(DagRun)
        .filter(
            DagRun.dag_id == dag_id,
            or_(DagRun.run_id == run_id, DagRun.execution_date == logical_date),
        )
        .first()
    )
    if not dagrun_instance:
        try:
            dag = current_app.dag_bag.get_dag(dag_id)
            if dag is None:
                # The DagModel row exists but the DAG bag has no DAG for it.
                # Without this check dag.create_dagrun raises AttributeError (500).
                raise NotFound(
                    title="DAG not found",
                    detail=f"DAG with dag_id: '{dag_id}' not found in the DAG bag",
                )
            dag_run = dag.create_dagrun(
                run_type=DagRunType.MANUAL,
                run_id=run_id,
                execution_date=logical_date,
                data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date),
                state=State.QUEUED,
                conf=post_body.get("conf"),
                external_trigger=True,
                dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
            )
        except ValueError as ve:
            raise BadRequest(detail=str(ve)) from ve
        return dagrun_schema.dump(dag_run)

    if dagrun_instance.execution_date == logical_date:
        raise AlreadyExists(
            detail=(
                f"DAGRun with DAG ID: '{dag_id}' and "
                f"DAGRun logical date: '{logical_date.isoformat(sep=' ')}' already exists"
            ),
        )

    raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists")
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances.

    The body selects task instances by date range or by ``dag_run_id``. A
    ``dag_run_id`` sets both bounds to that run's logical date.
    ``include_past``/``include_future`` open the lower/upper bound.
    ``task_ids`` restricts the DAG to those tasks, plus upstream/downstream
    tasks when requested. With ``dry_run`` set, nothing is cleared and only
    the matching task instances are reported.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG or the given DAG run does not exist
    """
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    dag_run_id = data.pop('dag_run_id', None)
    future = data.pop('include_future', False)
    past = data.pop('include_past', False)
    downstream = data.pop('include_downstream', False)
    upstream = data.pop('include_upstream', False)
    if dag_run_id is not None:
        dag_run: Optional[DR] = (
            session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
        )
        if dag_run is None:
            error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
            raise NotFound(error_message)
        data['start_date'] = dag_run.logical_date
        data['end_date'] = dag_run.logical_date
    if past:
        data['start_date'] = None
    if future:
        data['end_date'] = None
    task_ids = data.pop('task_ids', None)
    if task_ids is not None:
        # Entries may be plain task IDs or (task_id, map_index) tuples.
        requested_task_ids = [task[0] if isinstance(task, tuple) else task for task in task_ids]
        dag = dag.partial_subset(
            task_ids_or_regex=requested_task_ids,
            include_downstream=downstream,
            include_upstream=upstream,
        )

        if len(dag.task_dict) > 1:
            # If we had upstream/downstream etc then also include those!
            # Only add tasks that were not requested explicitly. Comparing each
            # ID to the whole list was always unequal, so requested tasks were
            # re-added as plain IDs and widened a (task_id, map_index) request
            # to every map index.
            already_requested = set(requested_task_ids)
            task_ids.extend(tid for tid in dag.task_dict if tid not in already_requested)
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all())
    )
Beispiel #18
0
def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse:
    """Update the ``is_paused`` flag of a single DAG.

    An ``update_mask``, if given, must be exactly ``['is_paused']``.

    :raises BadRequest: invalid body, or a mask naming any other field
    :raises NotFound: if no ``DagModel`` exists for ``dag_id``
    """
    try:
        patch_body = dag_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    if update_mask:
        if update_mask != ['is_paused']:
            raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
        patch_body = {'is_paused': patch_body['is_paused']}

    dag_model = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
    if not dag_model:
        raise NotFound(f"Dag with id: '{dag_id}' not found")

    dag_model.is_paused = patch_body['is_paused']
    session.flush()
    return dag_schema.dump(dag_model)
Beispiel #19
0
def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response:
    """Update a variable by key.

    :param variable_key: key from the URL; must equal ``key`` in the body
    :param update_mask: optional list of fields to update. It must not name
        ``key`` and must name ``value``. Entries are stripped of surrounding
        whitespace.
    :return: the serialized variable data from the request body
    :raises BadRequest: invalid body, mismatched key, or invalid update mask
    """
    try:
        data = variable_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))

    if data["key"] != variable_key:
        raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter")

    if update_mask:
        update_mask = [field.strip() for field in update_mask]
        if "key" in update_mask:
            raise BadRequest("key is a read only field")
        if "value" not in update_mask:
            raise BadRequest("No field to update")

    Variable.set(data["key"], data["val"])
    return variable_schema.dump(data)
Beispiel #20
0
def post_variables() -> Response:
    """Create (or overwrite) a variable from the JSON request body.

    :return: the serialized variable data
    :raises BadRequest: if the body fails ``variable_schema`` validation
    """
    try:
        data = variable_schema.load(request.json)
    except ValidationError as err:
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))

    key, value = data["key"], data["val"]
    Variable.set(key, value)
    return variable_schema.dump(data)
Beispiel #21
0
def post_set_task_instances_state(*,
                                  dag_id: str,
                                  session: Session = NEW_SESSION
                                  ) -> APIResponse:
    """Set a state of task instances.

    The target task instance is identified by ``execution_date`` or
    ``dag_run_id`` from the body and must exist before
    ``dag.set_task_instance_state`` is called. ``dry_run`` is forwarded as
    ``commit=not dry_run``.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG, the task, or the task instance is missing
    """
    body = request.get_json()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    error_message = f"Dag ID {dag_id} not found"
    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(error_message)

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)

    if not task:
        error_message = f"Task ID {task_id} not found"
        raise NotFound(error_message)

    execution_date = data.get('execution_date')
    run_id = data.get('dag_run_id')
    if execution_date:
        # TI rows are also keyed by map_index (see the run_id lookup below),
        # so several rows can match a date. first() is a pure existence check;
        # one_or_none() would raise MultipleResultsFound (a 500) instead.
        ti_for_date = (
            session.query(TI)
            .filter(TI.task_id == task_id, TI.dag_id == dag_id, TI.execution_date == execution_date)
            .first()
        )
        if ti_for_date is None:
            raise NotFound(
                detail=f"Task instance not found for task {task_id!r} on execution_date {execution_date}"
            )

    if run_id and not session.query(TI).get({
            'task_id': task_id,
            'dag_id': dag_id,
            'run_id': run_id,
            'map_index': -1
    }):
        error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
        raise NotFound(detail=error_message)

    tis = dag.set_task_instance_state(
        task_id=task_id,
        dag_run_id=run_id,
        execution_date=execution_date,
        state=data["new_state"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        commit=not data["dry_run"],
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis))
Beispiel #22
0
def post_pool(session):
    """Create a pool from the JSON request body.

    :return: the serialized new pool
    :raises BadRequest: if ``name`` or ``slots`` is missing, or the body
        fails schema validation
    :raises AlreadyExists: if committing the pool raises ``IntegrityError``
    """
    # Both fields are mandatory when creating a pool.
    for required in ("name", "slots"):
        if required not in request.json.keys():
            raise BadRequest(detail=f"'{required}' is a required property")

    try:
        post_body = pool_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    new_pool = Pool(**post_body)
    try:
        session.add(new_pool)
        session.commit()
        return pool_schema.dump(new_pool)
    except IntegrityError:
        raise AlreadyExists(detail=f"Pool: {post_body['pool']} already exists")
Beispiel #23
0
def delete_pool(*,
                pool_name: str,
                session: Session = NEW_SESSION) -> APIResponse:
    """Delete the pool called ``pool_name``.

    :return: an empty ``204 No Content`` response
    :raises BadRequest: if asked to delete ``default_pool``
    :raises NotFound: if no row was deleted
    """
    if pool_name == "default_pool":
        raise BadRequest(detail="Default Pool can't be deleted")

    matching = session.query(Pool).filter(Pool.pool == pool_name)
    if not matching.delete():
        raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
    return Response(status=HTTPStatus.NO_CONTENT)
Beispiel #24
0
def patch_variable(variable_key: str, update_mask: Optional[List[str]] = None) -> Response:
    """
    Update a variable by key.

    :param variable_key: key from the URL; must equal ``key`` in the body
    :param update_mask: optional list of fields to update. It must not name
        ``key`` and must name ``value``. Entries are stripped of surrounding
        whitespace.
    :return: an empty 204 response
    :raises BadRequest: invalid body, mismatched key, or invalid update mask
    """
    try:
        data = variable_schema.load(request.json)
    except ValidationError as err:
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))

    if data["key"] != variable_key:
        raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter")

    if update_mask:
        update_mask = [field.strip() for field in update_mask]
        if "key" in update_mask:
            raise BadRequest("key is a read only field")
        if "value" not in update_mask:
            raise BadRequest("No field to update")

    Variable.set(data["key"], data["val"])
    return Response(status=204)
Beispiel #25
0
def patch_dag(session, dag_id, update_mask=None):
    """Update the ``is_paused`` flag of a single DAG.

    The DAG is looked up before the body is validated. An ``update_mask``,
    if given, must contain exactly one entry, ``is_paused``.

    :raises NotFound: if no ``DagModel`` exists for ``dag_id``
    :raises BadRequest: invalid body, or a mask naming any other field
    """
    dag_model = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
    if not dag_model:
        raise NotFound(f"Dag with id: '{dag_id}' not found")
    try:
        patch_body = dag_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest("Invalid Dag schema", detail=str(err.messages))

    if update_mask:
        if len(update_mask) > 1 or update_mask[0] != 'is_paused':
            raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
        patch_body = {'is_paused': patch_body['is_paused']}

    dag_model.is_paused = patch_body['is_paused']
    session.commit()
    return dag_schema.dump(dag_model)
def post_pool(session):
    """Create a pool from the JSON request body.

    All missing required properties are reported together, sorted by name.

    :return: the serialized new pool
    :raises BadRequest: if ``name``/``slots`` are missing, or the body fails
        schema validation
    :raises AlreadyExists: if committing the pool raises ``IntegrityError``
    """
    # Pool would require both fields in the post request
    absent = {"name", "slots"}.difference(request.json.keys())
    if absent:
        raise BadRequest(
            detail=f"Missing required property(ies): {sorted(absent)}")

    try:
        post_body = pool_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    new_pool = Pool(**post_body)
    try:
        session.add(new_pool)
        session.commit()
        return pool_schema.dump(new_pool)
    except IntegrityError:
        raise AlreadyExists(detail=f"Pool: {post_body['pool']} already exists")
Beispiel #27
0
 def autogenerate(self, data, **kwargs):
     """Fill in ``execution_date`` and ``dag_run_id`` when ``data`` lacks them.

     A missing ``execution_date`` becomes ``str(timezone.utcnow())``. A
     missing ``dag_run_id`` is generated as a MANUAL run ID from the parsed
     ``execution_date``. ``data`` is modified in place and returned.

     :raises BadRequest: if ``execution_date`` cannot be parsed while a run
         ID is being generated
     """
     if "execution_date" not in data.keys():
         data["execution_date"] = str(timezone.utcnow())
     if "dag_run_id" in data.keys():
         return data
     try:
         parsed_date = timezone.parse(data["execution_date"])
         data["dag_run_id"] = DagRun.generate_run_id(DagRunType.MANUAL, parsed_date)
     except (ParserError, TypeError) as err:
         raise BadRequest("Incorrect datetime argument", detail=str(err))
     return data
Beispiel #28
0
def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None):
    """Patch multiple DAGs.

    Sets ``is_paused`` on every non-subDAG in the requested page that matches
    the filters and that the current user may edit.

    :param limit: page size
    :param session: SQLAlchemy session
    :param offset: number of matching DAGs to skip
    :param only_active: restrict to active DAGs
    :param tags: if given, keep DAGs carrying any of these tags
    :param dag_id_pattern: substring matched case-insensitively against the
        DAG ID. ``"~"``, ``None`` or ``""`` means no DAG ID filtering.
    :param update_mask: if given, must be exactly ``['is_paused']``
    :return: the selected page of DAGs and the total number of matches
    :raises BadRequest: invalid body, a mask naming another field, or a
        body without ``is_paused``
    """
    try:
        patch_body = dag_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    if update_mask and update_mask != ['is_paused']:
        raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
    # is_paused is the only field applied below; without it the update would
    # fail with KeyError (500).
    if 'is_paused' not in patch_body:
        raise BadRequest(detail="`is_paused` must be set in the request body")
    is_paused = patch_body['is_paused']

    if only_active:
        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active)
    else:
        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)

    # Filter only on a real pattern: interpolating None would match only
    # DAG IDs containing the text "None".
    if dag_id_pattern and dag_id_pattern != '~':
        dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
    editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)

    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
    if tags:
        cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
        dags_query = dags_query.filter(or_(*cond))

    total_entries = dags_query.count()

    dags = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()

    dags_to_update = {dag.dag_id for dag in dags}
    session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update(
        {DagModel.is_paused: is_paused}, synchronize_session='fetch'
    )

    session.flush()

    return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))
Beispiel #29
0
def format_datetime(value: str):
    """
    Datetime format parser for args since connexion doesn't parse datetimes
    https://github.com/zalando/connexion/issues/476

    Unless the value ends in ``Z``, spaces are turned into ``+``. This is
    presumably to undo query-string decoding of a ``+`` in a UTC offset;
    verify against callers.

    This should only be used within connexion views because it raises 400.

    :param value: datetime string taken from a request argument
    :return: the parsed datetime
    :raises BadRequest: if the value cannot be parsed
    """
    # endswith() is safe on "", where indexing value[-1] raised IndexError (500).
    if not value.endswith('Z'):
        value = value.replace(" ", '+')
    try:
        return timezone.parse(value)
    except (ParserError, TypeError) as err:
        raise BadRequest("Incorrect datetime argument", detail=str(err))
Beispiel #30
0
def patch_connection(
    *,
    connection_id: str,
    update_mask: UpdateMask = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Update a connection entry.

    The body is loaded with ``partial=True``. The connection ID cannot be
    changed. With an ``update_mask``, only the listed fields are applied, and
    each must be present in the body and not be an ID field.

    :raises BadRequest: extra/invalid fields, an attempt to change the
        connection ID, or an unknown/read-only field in the mask
    :raises NotFound: if no connection has ``connection_id``
    """
    try:
        data = connection_schema.load(request.json, partial=True)
    except ValidationError as err:
        # With partial=True, reaching here means extra-field validation failed.
        raise BadRequest(detail=str(err.messages))
    read_only_fields = ['connection_id', 'conn_id']

    connection = session.query(Connection).filter_by(conn_id=connection_id).first()
    if connection is None:
        raise NotFound(
            "Connection not found",
            detail=f"The Connection with connection_id: `{connection_id}` was not found",
        )
    if data.get('conn_id') and connection.conn_id != data['conn_id']:
        raise BadRequest(detail="The connection_id cannot be updated.")

    if update_mask:
        masked_fields = [name.strip() for name in update_mask]
        selected = {}
        for field in masked_fields:
            if field not in data or field in read_only_fields:
                raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.")
            selected[field] = data[field]
        data = selected

    for attr_name, new_value in data.items():
        setattr(connection, attr_name, new_value)
    session.add(connection)
    session.commit()
    return connection_schema.dump(connection)