Example #1
def post_clear_task_instances(*,
                              dag_id: str,
                              session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances."""
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    task_instances = dag.clear(dry_run=True,
                               dag_bag=get_airflow_app().dag_bag,
                               **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all()))
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances."""
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    dag_run_id = data.pop('dag_run_id', None)
    future = data.pop('include_future', False)
    past = data.pop('include_past', False)
    downstream = data.pop('include_downstream', False)
    upstream = data.pop('include_upstream', False)
    if dag_run_id is not None:
        dag_run: Optional[DR] = (
            session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
        )
        if dag_run is None:
            error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
            raise NotFound(error_message)
        data['start_date'] = dag_run.logical_date
        data['end_date'] = dag_run.logical_date
    if past:
        data['start_date'] = None
    if future:
        data['end_date'] = None
    task_ids = data.pop('task_ids', None)
    if task_ids is not None:
        task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids]
        dag = dag.partial_subset(
            task_ids_or_regex=task_id,
            include_downstream=downstream,
            include_upstream=upstream,
        )

        if len(dag.task_dict) > 1:
            # If we had upstream/downstream etc then also include those!
            task_ids.extend(tid for tid in dag.task_dict if tid != task_id)
    task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all())
    )
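A minimal client-side sketch of how this handler is usually exercised over the stable REST API, assuming it is wired to POST /api/v1/dags/{dag_id}/clearTaskInstances; the base URL, credentials, DAG id and run id below are placeholders, and the payload keys mirror the fields the handler pops from the form data.

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder basic-auth credentials

payload = {
    "dry_run": True,               # preview only; the handler clears for real when this is False
    "reset_dag_runs": True,        # maps to dag_run_state=QUEUED in clear_task_instances()
    "dag_run_id": "manual__2022-06-01T00:00:00+00:00",
    "include_upstream": False,
    "include_downstream": False,
    "include_future": False,
    "include_past": False,
}

resp = requests.post(f"{BASE_URL}/dags/example_dag/clearTaskInstances", json=payload, auth=AUTH)
resp.raise_for_status()
# TaskInstanceReferenceCollection: the task instances that would be (or were) cleared
print(resp.json())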
Example #3
def post_dag_run(*,
                 dag_id: str,
                 session: Session = NEW_SESSION) -> APIResponse:
    """Trigger a DAG."""
    dm = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if not dm:
        raise NotFound(title="DAG not found",
                       detail=f"DAG with dag_id: '{dag_id}' not found")
    if dm.has_import_errors:
        raise BadRequest(
            title="DAG cannot be triggered",
            detail=f"DAG with dag_id: '{dag_id}' has import errors",
        )
    try:
        post_body = dagrun_schema.load(get_json_request_dict(),
                                       session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    logical_date = pendulum.instance(post_body["execution_date"])
    run_id = post_body["run_id"]
    dagrun_instance = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id,
        or_(DagRun.run_id == run_id, DagRun.execution_date == logical_date),
    ).first())
    if not dagrun_instance:
        try:
            dag = get_airflow_app().dag_bag.get_dag(dag_id)
            dag_run = dag.create_dagrun(
                run_type=DagRunType.MANUAL,
                run_id=run_id,
                execution_date=logical_date,
                data_interval=dag.timetable.infer_manual_data_interval(
                    run_after=logical_date),
                state=DagRunState.QUEUED,
                conf=post_body.get("conf"),
                external_trigger=True,
                dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
            )
            return dagrun_schema.dump(dag_run)
        except ValueError as ve:
            raise BadRequest(detail=str(ve))

    if dagrun_instance.execution_date == logical_date:
        raise AlreadyExists(detail=(
            f"DAGRun with DAG ID: '{dag_id}' and "
            f"DAGRun logical date: '{logical_date.isoformat(sep=' ')}' already exists"
        ), )

    raise AlreadyExists(
        detail=
        f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists"
    )
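The trigger endpoint above is commonly driven like the hedged sketch below, assuming the handler is mounted at POST /api/v1/dags/{dag_id}/dagRuns and that the request body follows the stable API's DAGRun schema; host, credentials and identifiers are placeholders.

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials

payload = {
    "dag_run_id": "manual__2022-06-01T00:00:00+00:00",
    "logical_date": "2022-06-01T00:00:00+00:00",   # older deployments may expect "execution_date"
    "conf": {"param": "value"},                     # arbitrary run configuration, surfaced as post_body["conf"]
}

resp = requests.post(f"{BASE_URL}/dags/example_dag/dagRuns", json=payload, auth=AUTH)
if resp.status_code == 409:
    # AlreadyExists: a run with this run_id or logical date is already registered
    print(resp.json()["detail"])
else:
    resp.raise_for_status()
    print(resp.json())   # the new run is created in the QUEUED state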
Example #4
def update_dag_run_state(*,
                         dag_id: str,
                         dag_run_id: str,
                         session: Session = NEW_SESSION) -> APIResponse:
    """Set a state of a dag run."""
    dag_run: Optional[DagRun] = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none())
    if dag_run is None:
        error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
        raise NotFound(error_message)
    try:
        post_body = set_dagrun_state_form_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    state = post_body['state']
    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if state == DagRunState.SUCCESS:
        set_dag_run_state_to_success(dag=dag,
                                     run_id=dag_run.run_id,
                                     commit=True)
    elif state == DagRunState.QUEUED:
        set_dag_run_state_to_queued(dag=dag,
                                    run_id=dag_run.run_id,
                                    commit=True)
    else:
        set_dag_run_state_to_failed(dag=dag,
                                    run_id=dag_run.run_id,
                                    commit=True)
    dag_run = session.query(DagRun).get(dag_run.id)
    return dagrun_schema.dump(dag_run)
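A hedged usage sketch for the state-patching handler above, assuming the usual PATCH /api/v1/dags/{dag_id}/dagRuns/{dag_run_id} route; only "success", "queued" and "failed" are meaningful states here, and all identifiers are placeholders.

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials
run_id = "manual__2022-06-01T00:00:00+00:00"

# Mark an existing run as failed; the handler dispatches to the matching
# set_dag_run_state_to_* helper and returns the refreshed DagRun.
resp = requests.patch(
    f"{BASE_URL}/dags/example_dag/dagRuns/{run_id}",
    json={"state": "failed"},
    auth=AUTH,
)
resp.raise_for_status()
print(resp.json()["state"])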
def patch_role(*,
               role_name: str,
               update_mask: UpdateMask = None) -> APIResponse:
    """Update a role"""
    appbuilder = get_airflow_app().appbuilder
    security_manager = appbuilder.sm
    body = request.json
    try:
        data = role_schema.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    role = security_manager.find_role(name=role_name)
    if not role:
        raise NotFound(title="Role not found",
                       detail=f"Role with name {role_name!r} was not found")
    if update_mask:
        update_mask = [i.strip() for i in update_mask]
        data_ = {}
        for field in update_mask:
            if field in data and not field == "permissions":
                data_[field] = data[field]
            elif field == "actions":
                data_["permissions"] = data['permissions']
            else:
                raise BadRequest(detail=f"'{field}' in update_mask is unknown")
        data = data_
    if "permissions" in data:
        perms = [(item["action"]["name"], item["resource"]["name"])
                 for item in data["permissions"] if item]
        _check_action_and_resource(security_manager, perms)
        security_manager.bulk_sync_roles([{"role": role_name, "perms": perms}])
    new_name = data.get("name")
    if new_name is not None and new_name != role.name:
        security_manager.update_role(role_id=role.id, name=new_name)
    return role_schema.dump(role)
Example #6
def get_dags(
    *,
    limit: int,
    offset: int = 0,
    tags: Optional[Collection[str]] = None,
    dag_id_pattern: Optional[str] = None,
    only_active: bool = True,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get all DAGs."""
    if only_active:
        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active)
    else:
        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)

    if dag_id_pattern:
        dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))

    readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)

    dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags))
    if tags:
        cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
        dags_query = dags_query.filter(or_(*cond))

    total_entries = dags_query.count()

    dags = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()

    return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))
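Since get_dags takes its filters straight from the query string, a client call looks roughly like the sketch below (assuming the conventional GET /api/v1/dags route; host and credentials are placeholders).

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials

# limit/offset paginate, tags filters with OR semantics, and dag_id_pattern is a
# case-insensitive substring match (the handler wraps it in %...% for ILIKE).
params = {
    "limit": 50,
    "offset": 0,
    "tags": ["example"],
    "dag_id_pattern": "tutorial",
    "only_active": True,
}

resp = requests.get(f"{BASE_URL}/dags", params=params, auth=AUTH)
resp.raise_for_status()
body = resp.json()
print(body["total_entries"], [d["dag_id"] for d in body["dags"]])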
Example #7
def get_users(*,
              limit: int,
              order_by: str = "id",
              offset: Optional[str] = None) -> APIResponse:
    """Get users"""
    appbuilder = get_airflow_app().appbuilder
    session = appbuilder.get_session
    total_entries = session.query(func.count(User.id)).scalar()
    direction = desc if order_by.startswith("-") else asc
    to_replace = {"user_id": "id"}
    order_param = order_by.strip("-")
    order_param = to_replace.get(order_param, order_param)
    allowed_filter_attrs = [
        'id',
        "first_name",
        "last_name",
        "user_name",
        "email",
        "is_active",
        "role",
    ]
    if order_by not in allowed_filter_attrs:
        raise BadRequest(detail=f"Ordering with '{order_by}' is disallowed or "
                         f"the attribute does not exist on the model")

    query = session.query(User)
    users = query.order_by(direction(getattr(
        User, order_param))).offset(offset).limit(limit).all()

    return user_collection_schema.dump(
        UserCollection(users=users, total_entries=total_entries))
Example #8
def get_xcom_entries(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    limit: Optional[int],
    offset: Optional[int] = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get all XCom values"""
    query = session.query(XCom)
    if dag_id == '~':
        appbuilder = get_airflow_app().appbuilder
        readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
        query = query.filter(XCom.dag_id.in_(readable_dag_ids))
        query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
    else:
        query = query.filter(XCom.dag_id == dag_id)
        query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))

    if task_id != '~':
        query = query.filter(XCom.task_id == task_id)
    if dag_run_id != '~':
        query = query.filter(DR.run_id == dag_run_id)
    query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key)
    total_entries = query.count()
    query = query.offset(offset).limit(limit)
    return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries))
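The XCom listing accepts "~" as a wildcard for dag_id, dag_run_id and task_id, as the branching above shows. A hedged client sketch, assuming the usual GET .../taskInstances/{task_id}/xcomEntries route and placeholder identifiers:

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials

# Pin the DAG and task but take XCom entries from every run of that task.
resp = requests.get(
    f"{BASE_URL}/dags/example_dag/dagRuns/~/taskInstances/extract/xcomEntries",
    params={"limit": 10, "offset": 0},
    auth=AUTH,
)
resp.raise_for_status()
for entry in resp.json()["xcom_entries"]:
    print(entry["key"], entry["task_id"])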
def get_role(*, role_name: str) -> APIResponse:
    """Get role"""
    ab_security_manager = get_airflow_app().appbuilder.sm
    role = ab_security_manager.find_role(name=role_name)
    if not role:
        raise NotFound(title="Role not found",
                       detail=f"Role with name {role_name!r} was not found")
    return role_schema.dump(role)
Example #10
def check_authentication() -> None:
    """Checks that the request has valid authorization information."""
    for auth in get_airflow_app().api_auth:
        response = auth.requires_authentication(Response)()
        if response.status_code == 200:
            return
    # since this handler only checks authentication, not authorization,
    # we should always return 401
    raise Unauthenticated(headers=response.headers)
Example #11
def get_user(*, username: str) -> APIResponse:
    """Get a user"""
    ab_security_manager = get_airflow_app().appbuilder.sm
    user = ab_security_manager.find_user(username=username)
    if not user:
        raise NotFound(
            title="User not found",
            detail=f"The User with username `{username}` was not found")
    return user_collection_item_schema.dump(user)
def delete_role(*, role_name: str) -> APIResponse:
    """Delete a role"""
    ab_security_manager = get_airflow_app().appbuilder.sm
    role = ab_security_manager.find_role(name=role_name)
    if not role:
        raise NotFound(title="Role not found",
                       detail=f"Role with name {role_name!r} was not found")
    ab_security_manager.delete_role(role_name=role_name)
    return NoContent, HTTPStatus.NO_CONTENT
Example #13
def post_set_task_instances_state(*,
                                  dag_id: str,
                                  session: Session = NEW_SESSION
                                  ) -> APIResponse:
    """Set a state of task instances."""
    body = get_json_request_dict()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    error_message = f"Dag ID {dag_id} not found"
    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(error_message)

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)

    if not task:
        error_message = f"Task ID {task_id} not found"
        raise NotFound(error_message)

    execution_date = data.get('execution_date')
    run_id = data.get('dag_run_id')
    if (execution_date and (session.query(TI).filter(
            TI.task_id == task_id, TI.dag_id == dag_id, TI.execution_date
            == execution_date).one_or_none()) is None):
        raise NotFound(
            detail=
            f"Task instance not found for task {task_id!r} on execution_date {execution_date}"
        )

    if run_id and not session.query(TI).get({
            'task_id': task_id,
            'dag_id': dag_id,
            'run_id': run_id,
            'map_index': -1
    }):
        error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
        raise NotFound(detail=error_message)

    tis = dag.set_task_instance_state(
        task_id=task_id,
        run_id=run_id,
        execution_date=execution_date,
        state=data["new_state"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        commit=not data["dry_run"],
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis))
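A hedged sketch of calling the handler above, assuming it is mounted at POST /api/v1/dags/{dag_id}/updateTaskInstancesState; the payload keys mirror the fields read from the form, and every identifier is a placeholder.

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials

payload = {
    "dry_run": True,                 # preview which task instances would change state
    "task_id": "extract",
    "dag_run_id": "manual__2022-06-01T00:00:00+00:00",
    "new_state": "success",          # "failed" is the other common choice
    "include_upstream": False,
    "include_downstream": True,
    "include_future": False,
    "include_past": False,
}

resp = requests.post(
    f"{BASE_URL}/dags/example_dag/updateTaskInstancesState",
    json=payload,
    auth=AUTH,
)
resp.raise_for_status()
print(resp.json()["task_instances"])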
def get_permissions(*,
                    limit: int,
                    offset: Optional[int] = None) -> APIResponse:
    """Get permissions"""
    session = get_airflow_app().appbuilder.get_session
    total_entries = session.query(func.count(Action.id)).scalar()
    query = session.query(Action)
    actions = query.offset(offset).limit(limit).all()
    return action_collection_schema.dump(
        ActionCollection(actions=actions, total_entries=total_entries))
Example #15
def delete_user(*, username: str) -> APIResponse:
    """Delete a user"""
    security_manager = get_airflow_app().appbuilder.sm

    user = security_manager.find_user(username=username)
    if user is None:
        detail = f"The User with username `{username}` was not found"
        raise NotFound(title="User not found", detail=detail)

    user.roles = []  # Clear foreign keys on this user first.
    security_manager.get_session.delete(user)
    security_manager.get_session.commit()

    return NoContent, HTTPStatus.NO_CONTENT
Example #16
def auth_current_user() -> Optional[User]:
    """Authenticate and set current user if Authorization header exists"""
    auth = request.authorization
    if auth is None or not auth.username or not auth.password:
        return None

    ab_security_manager = get_airflow_app().appbuilder.sm
    user = None
    if ab_security_manager.auth_type == AUTH_LDAP:
        user = ab_security_manager.auth_user_ldap(auth.username, auth.password)
    if user is None:
        user = ab_security_manager.auth_user_db(auth.username, auth.password)
    if user is not None:
        login_user(user, remember=False)
    return user
Example #17
def requires_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Callable[[T], T]:
    """Factory for decorator that checks current user's permissions against required permissions."""
    appbuilder = get_airflow_app().appbuilder
    appbuilder.sm.sync_resource_permissions(permissions)

    def requires_access_decorator(func: T):
        @wraps(func)
        def decorated(*args, **kwargs):
            check_authentication()
            if appbuilder.sm.check_authorization(permissions, kwargs.get('dag_id')):
                return func(*args, **kwargs)
            raise PermissionDenied()

        return cast(T, decorated)

    return requires_access_decorator
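For context, the factory above is applied as a decorator inside Airflow's own endpoint modules, roughly as in this sketch; the handler name is hypothetical, but the permission constants come from airflow.security.permissions, and decoration happens at import time inside the webserver app, where get_airflow_app() is available.

from airflow.api_connexion import security
from airflow.security import permissions


# Hypothetical endpoint: the decorator syncs the declared (action, resource)
# pairs and wraps the function so each call runs check_authentication() and
# then check_authorization() before the handler body executes.
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)])
def get_some_dag_runs(*, dag_id: str):
    ...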
Example #18
def clear_dag_run(*,
                  dag_id: str,
                  dag_run_id: str,
                  session: Session = NEW_SESSION) -> APIResponse:
    """Clear a dag run."""
    dag_run: Optional[DagRun] = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none())
    if dag_run is None:
        error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
        raise NotFound(error_message)
    try:
        post_body = clear_dagrun_form_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    dry_run = post_body.get('dry_run', False)
    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    start_date = dag_run.logical_date
    end_date = dag_run.logical_date

    if dry_run:
        task_instances = dag.clear(
            start_date=start_date,
            end_date=end_date,
            task_ids=None,
            include_subdags=True,
            include_parentdag=True,
            only_failed=False,
            dry_run=True,
        )
        return task_instance_reference_collection_schema.dump(
            TaskInstanceReferenceCollection(task_instances=task_instances))
    else:
        dag.clear(
            start_date=start_date,
            end_date=end_date,
            task_ids=None,
            include_subdags=True,
            include_parentdag=True,
            only_failed=False,
        )
        dag_run.refresh_from_db()
        return dagrun_schema.dump(dag_run)
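A hedged client sketch for clear_dag_run, assuming the conventional POST /api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear route with placeholder identifiers; flipping dry_run changes both the behaviour and the response shape, exactly as the two branches above do.

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials
run_id = "manual__2022-06-01T00:00:00+00:00"

# dry_run=True returns the task instance references that would be cleared;
# dry_run=False clears them and returns the refreshed DAG run instead.
resp = requests.post(
    f"{BASE_URL}/dags/example_dag/dagRuns/{run_id}/clear",
    json={"dry_run": True},
    auth=AUTH,
)
resp.raise_for_status()
print(resp.json())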
Example #19
def get_dag_runs(
    *,
    dag_id: str,
    start_date_gte: Optional[str] = None,
    start_date_lte: Optional[str] = None,
    execution_date_gte: Optional[str] = None,
    execution_date_lte: Optional[str] = None,
    end_date_gte: Optional[str] = None,
    end_date_lte: Optional[str] = None,
    state: Optional[List[str]] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    order_by: str = "id",
    session: Session = NEW_SESSION,
):
    """Get all DAG Runs."""
    query = session.query(DagRun)

    #  This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
    if dag_id == "~":
        appbuilder = get_airflow_app().appbuilder
        query = query.filter(
            DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
    else:
        query = query.filter(DagRun.dag_id == dag_id)

    if state:
        query = query.filter(DagRun.state.in_(state))

    dag_run, total_entries = _fetch_dag_runs(
        query,
        end_date_gte=end_date_gte,
        end_date_lte=end_date_lte,
        execution_date_gte=execution_date_gte,
        execution_date_lte=execution_date_lte,
        start_date_gte=start_date_gte,
        start_date_lte=start_date_lte,
        limit=limit,
        offset=offset,
        order_by=order_by,
    )
    return dagrun_collection_schema.dump(
        DAGRunCollection(dag_runs=dag_run, total_entries=total_entries))
Example #20
def post_user() -> APIResponse:
    """Create a new user"""
    try:
        data = user_schema.load(request.json)
    except ValidationError as e:
        raise BadRequest(detail=str(e.messages))

    security_manager = get_airflow_app().appbuilder.sm
    username = data["username"]
    email = data["email"]

    if security_manager.find_user(username=username):
        detail = f"Username `{username}` already exists. Use PATCH to update."
        raise AlreadyExists(detail=detail)
    if security_manager.find_user(email=email):
        detail = f"The email `{email}` is already taken."
        raise AlreadyExists(detail=detail)

    roles_to_add = []
    missing_role_names = []
    for role_data in data.pop("roles", ()):
        role_name = role_data["name"]
        role = security_manager.find_role(role_name)
        if role is None:
            missing_role_names.append(role_name)
        else:
            roles_to_add.append(role)
    if missing_role_names:
        detail = f"Unknown roles: {', '.join(repr(n) for n in missing_role_names)}"
        raise BadRequest(detail=detail)

    if not roles_to_add:  # No roles provided, use the F.A.B's default registered user role.
        roles_to_add.append(
            security_manager.find_role(
                security_manager.auth_user_registration_role))

    user = security_manager.add_user(role=roles_to_add, **data)
    if not user:
        detail = f"Failed to add user `{username}`."
        return Unknown(detail=detail)

    return user_schema.dump(user)
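A hedged sketch of creating a user through the handler above, assuming the usual POST /api/v1/users route and that the body follows the stable API's user schema; all names, addresses and credentials are placeholders.

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials

payload = {
    "username": "jdoe",
    "first_name": "Jane",
    "last_name": "Doe",
    "email": "jdoe@example.com",
    "password": "change-me",
    # Omitting "roles" falls back to the F.A.B. default registered-user role,
    # as the handler shows; unknown role names produce a 400.
    "roles": [{"name": "Viewer"}],
}

resp = requests.post(f"{BASE_URL}/users", json=payload, auth=AUTH)
if resp.status_code == 409:
    print(resp.json()["detail"])   # username or email already taken
else:
    resp.raise_for_status()
    print(resp.json()["username"])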
def post_role() -> APIResponse:
    """Create a new role"""
    appbuilder = get_airflow_app().appbuilder
    security_manager = appbuilder.sm
    body = request.json
    try:
        data = role_schema.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    role = security_manager.find_role(name=data['name'])
    if not role:
        perms = [(item['action']['name'], item['resource']['name'])
                 for item in data['permissions'] if item]
        _check_action_and_resource(security_manager, perms)
        security_manager.bulk_sync_roles([{
            "role": data["name"],
            "perms": perms
        }])
        return role_schema.dump(role)
    detail = f"Role with name {role.name!r} already exists; please update with the PATCH endpoint"
    raise AlreadyExists(detail=detail)
Example #22
def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None):
    """Patch multiple DAGs."""
    try:
        patch_body = dag_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    if update_mask:
        patch_body_ = {}
        if update_mask != ['is_paused']:
            raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
        update_mask = update_mask[0]
        patch_body_[update_mask] = patch_body[update_mask]
        patch_body = patch_body_
    if only_active:
        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active)
    else:
        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)

    if dag_id_pattern == '~':
        dag_id_pattern = '%'
    dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
    editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user)

    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
    if tags:
        cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
        dags_query = dags_query.filter(or_(*cond))

    total_entries = dags_query.count()

    dags = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()

    dags_to_update = {dag.dag_id for dag in dags}
    session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update(
        {DagModel.is_paused: patch_body['is_paused']}, synchronize_session='fetch'
    )

    session.flush()

    return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))
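patch_dags is the bulk counterpart of patching a single DAG: a hedged sketch, assuming the conventional PATCH /api/v1/dags route where dag_id_pattern and update_mask travel in the query string and only is_paused may be changed (exactly what the update_mask check enforces); host and credentials are placeholders.

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials

# Pause every editable DAG whose id contains "example".
resp = requests.patch(
    f"{BASE_URL}/dags",
    params={"dag_id_pattern": "example", "update_mask": "is_paused", "limit": 100},
    json={"is_paused": True},
    auth=AUTH,
)
resp.raise_for_status()
print([d["dag_id"] for d in resp.json()["dags"]])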
Example #23
def get_extra_links(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get extra links for task instance"""
    from airflow.models.taskinstance import TaskInstance

    dagbag: DagBag = get_airflow_app().dag_bag
    dag: DAG = dagbag.get_dag(dag_id)
    if not dag:
        raise NotFound("DAG not found",
                       detail=f'DAG with ID = "{dag_id}" not found')

    try:
        task = dag.get_task(task_id)
    except TaskNotFound:
        raise NotFound("Task not found",
                       detail=f'Task with ID = "{task_id}" not found')

    ti = (session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag_id,
        TaskInstance.run_id == dag_run_id,
        TaskInstance.task_id == task_id,
    ).one_or_none())

    if not ti:
        raise NotFound("DAG Run not found",
                       detail=f'DAG Run with ID = "{dag_run_id}" not found')

    all_extra_link_pairs = ((link_name, task.get_extra_links(ti, link_name))
                            for link_name in task.extra_links)
    all_extra_links = {
        link_name: link_url if link_url else None
        for link_name, link_url in sorted(all_extra_link_pairs)
    }
    return all_extra_links
def get_roles(*,
              order_by: str = "name",
              limit: int,
              offset: Optional[int] = None) -> APIResponse:
    """Get roles"""
    appbuilder = get_airflow_app().appbuilder
    session = appbuilder.get_session
    total_entries = session.query(func.count(Role.id)).scalar()
    direction = desc if order_by.startswith("-") else asc
    to_replace = {"role_id": "id"}
    order_param = order_by.strip("-")
    order_param = to_replace.get(order_param, order_param)
    allowed_filter_attrs = ["role_id", "name"]
    if order_by not in allowed_filter_attrs:
        raise BadRequest(detail=f"Ordering with '{order_by}' is disallowed or "
                         f"the attribute does not exist on the model")

    query = session.query(Role)
    roles = query.order_by(direction(getattr(
        Role, order_param))).offset(offset).limit(limit).all()

    return role_collection_schema.dump(
        RoleCollection(roles=roles, total_entries=total_entries))
Example #25
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
    """Get list of DAG Runs"""
    body = get_json_request_dict()
    try:
        data = dagruns_batch_form_schema.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    appbuilder = get_airflow_app().appbuilder
    readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
    query = session.query(DagRun)
    if data.get("dag_ids"):
        dag_ids = set(data["dag_ids"]) & set(readable_dag_ids)
        query = query.filter(DagRun.dag_id.in_(dag_ids))
    else:
        query = query.filter(DagRun.dag_id.in_(readable_dag_ids))

    states = data.get("states")
    if states:
        query = query.filter(DagRun.state.in_(states))

    dag_runs, total_entries = _fetch_dag_runs(
        query,
        end_date_gte=data["end_date_gte"],
        end_date_lte=data["end_date_lte"],
        execution_date_gte=data["execution_date_gte"],
        execution_date_lte=data["execution_date_lte"],
        start_date_gte=data["start_date_gte"],
        start_date_lte=data["start_date_lte"],
        limit=data["page_limit"],
        offset=data["page_offset"],
        order_by=data.get("order_by", "id"),
    )

    return dagrun_collection_schema.dump(
        DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))
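The batch listing takes its filters in the request body rather than the query string, which keeps long dag_id lists and date ranges manageable. A hedged sketch, assuming the usual POST /api/v1/dags/~/dagRuns/list route and placeholder values:

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials

payload = {
    "dag_ids": ["example_dag", "tutorial"],          # intersected with the DAGs the caller may read
    "states": ["failed"],
    "start_date_gte": "2022-06-01T00:00:00+00:00",
    "page_limit": 25,
    "page_offset": 0,
}

resp = requests.post(f"{BASE_URL}/dags/~/dagRuns/list", json=payload, auth=AUTH)
resp.raise_for_status()
print(resp.json()["total_entries"])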
def get_mapped_task_instances(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    execution_date_gte: Optional[str] = None,
    execution_date_lte: Optional[str] = None,
    start_date_gte: Optional[str] = None,
    start_date_lte: Optional[str] = None,
    end_date_gte: Optional[str] = None,
    end_date_lte: Optional[str] = None,
    duration_gte: Optional[float] = None,
    duration_lte: Optional[float] = None,
    state: Optional[List[str]] = None,
    pool: Optional[List[str]] = None,
    queue: Optional[List[str]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    order_by: Optional[str] = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get list of task instances."""
    # Because state can be 'none'
    states = _convert_state(state)

    base_query = (
        session.query(TI)
        .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index >= 0)
        .join(TI.dag_run)
    )

    # 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404
    if base_query.with_entities(func.count('*')).scalar() == 0:
        dag = get_airflow_app().dag_bag.get_dag(dag_id)
        if not dag:
            error_message = f"DAG {dag_id} not found"
            raise NotFound(error_message)
        task = dag.get_task(task_id)
        if not task:
            error_message = f"Task id {task_id} not found"
            raise NotFound(error_message)
        if not task.is_mapped:
            error_message = f"Task id {task_id} is not mapped"
            raise NotFound(error_message)

    # Other search criteria
    query = _apply_range_filter(
        base_query,
        key=DR.execution_date,
        value_range=(execution_date_gte, execution_date_lte),
    )
    query = _apply_range_filter(query, key=TI.start_date, value_range=(start_date_gte, start_date_lte))
    query = _apply_range_filter(query, key=TI.end_date, value_range=(end_date_gte, end_date_lte))
    query = _apply_range_filter(query, key=TI.duration, value_range=(duration_gte, duration_lte))
    query = _apply_array_filter(query, key=TI.state, values=states)
    query = _apply_array_filter(query, key=TI.pool, values=pool)
    query = _apply_array_filter(query, key=TI.queue, values=queue)

    # Count elements before joining extra columns
    total_entries = query.with_entities(func.count('*')).scalar()

    # Add SLA miss
    query = (
        query.join(
            SlaMiss,
            and_(
                SlaMiss.dag_id == TI.dag_id,
                SlaMiss.task_id == TI.task_id,
                SlaMiss.execution_date == DR.execution_date,
            ),
            isouter=True,
        )
        .add_entity(SlaMiss)
        .options(joinedload(TI.rendered_task_instance_fields))
    )

    if order_by:
        if order_by == 'state':
            query = query.order_by(TI.state.asc(), TI.map_index.asc())
        elif order_by == '-state':
            query = query.order_by(TI.state.desc(), TI.map_index.asc())
        elif order_by == '-map_index':
            query = query.order_by(TI.map_index.desc())
        else:
            raise BadRequest(detail=f"Ordering with '{order_by}' is not supported")
    else:
        query = query.order_by(TI.map_index.asc())

    task_instances = query.offset(offset).limit(limit).all()
    return task_instance_collection_schema.dump(
        TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries)
    )
Example #27
def get_dag_details(*, dag_id: str) -> APIResponse:
    """Get details of DAG."""
    dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found")
    return dag_detail_schema.dump(dag)
Example #28
def patch_user(*,
               username: str,
               update_mask: UpdateMask = None) -> APIResponse:
    """Update a user"""
    try:
        data = user_schema.load(request.json)
    except ValidationError as e:
        raise BadRequest(detail=str(e.messages))

    security_manager = get_airflow_app().appbuilder.sm

    user = security_manager.find_user(username=username)
    if user is None:
        detail = f"The User with username `{username}` was not found"
        raise NotFound(title="User not found", detail=detail)
    # Check unique username
    new_username = data.get('username')
    if new_username and new_username != username:
        if security_manager.find_user(username=new_username):
            raise AlreadyExists(
                detail=f"The username `{new_username}` already exists")

    # Check unique email
    email = data.get('email')
    if email and email != user.email:
        if security_manager.find_user(email=email):
            raise AlreadyExists(detail=f"The email `{email}` already exists")

    # Get fields to update.
    if update_mask is not None:
        masked_data = {}
        missing_mask_names = []
        for field in update_mask:
            field = field.strip()
            try:
                masked_data[field] = data[field]
            except KeyError:
                missing_mask_names.append(field)
        if missing_mask_names:
            detail = f"Unknown update masks: {', '.join(repr(n) for n in missing_mask_names)}"
            raise BadRequest(detail=detail)
        data = masked_data

    roles_to_update: Optional[List[Role]]
    if "roles" in data:
        roles_to_update = []
        missing_role_names = []
        for role_data in data.pop("roles", ()):
            role_name = role_data["name"]
            role = security_manager.find_role(role_name)
            if role is None:
                missing_role_names.append(role_name)
            else:
                roles_to_update.append(role)
        if missing_role_names:
            detail = f"Unknown roles: {', '.join(repr(n) for n in missing_role_names)}"
            raise BadRequest(detail=detail)
    else:
        roles_to_update = None  # Don't change existing value.

    if "password" in data:
        user.password = generate_password_hash(data.pop("password"))
    if roles_to_update is not None:
        user.roles = roles_to_update
    for key, value in data.items():
        setattr(user, key, value)
    security_manager.update_user(user)

    return user_schema.dump(user)
Example #29
def get_log(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    task_try_number: int,
    full_content: bool = False,
    token: Optional[str] = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get logs for specific task instance"""
    key = get_airflow_app().config["SECRET_KEY"]
    if not token:
        metadata = {}
    else:
        try:
            metadata = URLSafeSerializer(key).loads(token)
        except BadSignature:
            raise BadRequest(
                "Bad Signature. Please use only the tokens provided by the API."
            )

    if metadata.get('download_logs') and metadata['download_logs']:
        full_content = True

    if full_content:
        metadata['download_logs'] = True
    else:
        metadata['download_logs'] = False

    task_log_reader = TaskLogReader()
    if not task_log_reader.supports_read:
        raise BadRequest("Task log handler does not support read logs.")

    ti = (session.query(TaskInstance).filter(
        TaskInstance.task_id == task_id,
        TaskInstance.dag_id == dag_id,
        TaskInstance.run_id == dag_run_id,
    ).join(TaskInstance.dag_run).one_or_none())
    if ti is None:
        metadata['end_of_log'] = True
        raise NotFound(title="TaskInstance not found")

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if dag:
        try:
            ti.task = dag.get_task(ti.task_id)
        except TaskNotFound:
            pass

    return_type = request.accept_mimetypes.best_match(
        ['text/plain', 'application/json'])

    # return_type would be either the above two or None
    logs: Any
    if return_type == 'application/json' or return_type is None:  # default
        logs, metadata = task_log_reader.read_log_chunks(
            ti, task_try_number, metadata)
        logs = logs[0] if task_try_number is not None else logs
        # we must have token here, so we can safely ignore it
        token = URLSafeSerializer(key).dumps(
            metadata)  # type: ignore[assignment]
        return logs_schema.dump(
            LogResponseObject(continuation_token=token, content=logs))
    # text/plain. Stream
    logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)

    return Response(logs, headers={"Content-Type": return_type})
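Finally, a hedged sketch of consuming get_log, assuming the conventional GET .../taskInstances/{task_id}/logs/{task_try_number} route; the JSON variant returns a chunk plus a continuation token to resume from, while an Accept header of text/plain streams the raw log, matching the two branches above. Identifiers and credentials are placeholders.

import requests

BASE_URL = "http://localhost:8080/api/v1"   # placeholder webserver
AUTH = ("admin", "admin")                    # placeholder credentials
run_id = "manual__2022-06-01T00:00:00+00:00"
url = f"{BASE_URL}/dags/example_dag/dagRuns/{run_id}/taskInstances/extract/logs/1"

# Default JSON response: a chunk of log content plus a continuation token.
first = requests.get(url, auth=AUTH, headers={"Accept": "application/json"})
first.raise_for_status()
token = first.json()["continuation_token"]

# Pass the token back via the "token" query parameter to resume reading.
more = requests.get(url, params={"token": token}, auth=AUTH,
                    headers={"Accept": "application/json"})
print(more.json()["content"])

# Asking for text/plain streams the raw log instead.
raw = requests.get(url, auth=AUTH, headers={"Accept": "text/plain"})
print(raw.text[:200])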