def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear the task instances of a DAG that match the JSON request body.

    The body is validated with ``clear_task_instance_form``; ``reset_dag_runs``
    and ``dry_run`` are consumed here and every remaining key is forwarded to
    ``dag.clear()`` as a keyword argument.

    :raises BadRequest: the request body fails schema validation.
    :raises NotFound: no DAG with ``dag_id`` is in the DAG bag.
    :return: the serialized references of the selected task instances.
    """
    payload = get_json_request_dict()
    try:
        form = clear_task_instance_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag id {dag_id} not found")

    should_reset_runs = form.pop('reset_dag_runs')
    preview_only = form.pop('dry_run')

    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    selected = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, **form)

    if not preview_only:
        if should_reset_runs:
            run_state = DagRunState.QUEUED
        else:
            run_state = False
        clear_task_instances(selected.all(), session, dag=dag, dag_run_state=run_state)

    # The query is evaluated again here, after any clearing has been applied.
    collection = TaskInstanceReferenceCollection(task_instances=selected.all())
    return task_instance_reference_collection_schema.dump(collection)
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances of a DAG, optionally scoped to one DAG run and a set of tasks.

    The JSON request body is validated with ``clear_task_instance_form``. The
    keys ``reset_dag_runs``, ``dry_run``, ``dag_run_id``, ``include_future``,
    ``include_past``, ``include_downstream``, ``include_upstream`` and
    ``task_ids`` are consumed here; every remaining key is forwarded to
    ``dag.clear()`` as a keyword argument.

    :raises BadRequest: the request body fails schema validation.
    :raises NotFound: the DAG, or the requested DAG run, does not exist.
    :return: the serialized references of the selected task instances.
    """
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)

    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    dag_run_id = data.pop('dag_run_id', None)
    future = data.pop('include_future', False)
    past = data.pop('include_past', False)
    downstream = data.pop('include_downstream', False)
    upstream = data.pop('include_upstream', False)

    if dag_run_id is not None:
        dag_run: Optional[DR] = (
            session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
        )
        if dag_run is None:
            error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
            raise NotFound(error_message)
        # Restrict the date window to exactly this run ...
        data['start_date'] = dag_run.logical_date
        data['end_date'] = dag_run.logical_date

    # ... then open the window on whichever side was requested.
    if past:
        data['start_date'] = None
    if future:
        data['end_date'] = None

    task_ids = data.pop('task_ids', None)
    if task_ids is not None:
        # Entries may be plain task ids or tuples whose first item is the task id.
        requested_ids = [task[0] if isinstance(task, tuple) else task for task in task_ids]
        dag = dag.partial_subset(
            task_ids_or_regex=requested_ids,
            include_downstream=downstream,
            include_upstream=upstream,
        )

        if len(dag.task_dict) > 1:
            # If we had upstream/downstream etc then also include those!
            # Only ids that were not requested explicitly are appended; the
            # previous `tid != task_id` compared a str with the whole list,
            # was always true, and re-added every requested id.
            already_requested = set(requested_ids)
            task_ids.extend(tid for tid in dag.task_dict if tid not in already_requested)

    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all())
    )
def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Trigger a DAG by creating a manual, queued DAG run from the request body.

    :raises NotFound: no ``DagModel`` row exists for ``dag_id``.
    :raises BadRequest: the DAG has import errors, the body fails validation,
        or creating the run raises ``ValueError``.
    :raises AlreadyExists: a run of this DAG already has the same run id or
        the same logical date.
    :return: the serialized new DAG run.
    """
    dag_model = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if not dag_model:
        raise NotFound(title="DAG not found", detail=f"DAG with dag_id: '{dag_id}' not found")
    if dag_model.has_import_errors:
        raise BadRequest(
            title="DAG cannot be triggered",
            detail=f"DAG with dag_id: '{dag_id}' has import errors",
        )

    try:
        post_body = dagrun_schema.load(get_json_request_dict(), session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    logical_date = pendulum.instance(post_body["execution_date"])
    run_id = post_body["run_id"]

    clash_condition = or_(DagRun.run_id == run_id, DagRun.execution_date == logical_date)
    existing_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, clash_condition).first()

    if existing_run:
        # Report which of the two uniqueness constraints was hit.
        if existing_run.execution_date == logical_date:
            raise AlreadyExists(
                detail=(
                    f"DAGRun with DAG ID: '{dag_id}' and "
                    f"DAGRun logical date: '{logical_date.isoformat(sep=' ')}' already exists"
                ),
            )
        raise AlreadyExists(
            detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists"
        )

    try:
        dag = get_airflow_app().dag_bag.get_dag(dag_id)
        data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
        conf = post_body.get("conf")
        dag_hash = get_airflow_app().dag_bag.dags_hash.get(dag_id)
        created_run = dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            run_id=run_id,
            execution_date=logical_date,
            data_interval=data_interval,
            state=DagRunState.QUEUED,
            conf=conf,
            external_trigger=True,
            dag_hash=dag_hash,
        )
        return dagrun_schema.dump(created_run)
    except ValueError as ve:
        raise BadRequest(detail=str(ve))
def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Set the state of one DAG run to the ``state`` given in the request body.

    ``DagRunState.SUCCESS`` and ``DagRunState.QUEUED`` have dedicated setters;
    any other value marks the run as failed.

    :raises NotFound: the DAG run does not exist.
    :raises BadRequest: the request body fails schema validation.
    :return: the serialized DAG run, re-read after the change.
    """
    run_query = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
    dag_run: Optional[DagRun] = run_query.one_or_none()
    if dag_run is None:
        raise NotFound(f'Dag Run id {dag_run_id} not found in dag {dag_id}')

    try:
        post_body = set_dagrun_state_form_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    state = post_body['state']
    dag = get_airflow_app().dag_bag.get_dag(dag_id)

    if state == DagRunState.SUCCESS:
        apply_state = set_dag_run_state_to_success
    elif state == DagRunState.QUEUED:
        apply_state = set_dag_run_state_to_queued
    else:
        apply_state = set_dag_run_state_to_failed
    apply_state(dag=dag, run_id=dag_run.run_id, commit=True)

    refreshed_run = session.query(DagRun).get(dag_run.id)
    return dagrun_schema.dump(refreshed_run)
def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse:
    """Update the name and/or permissions of an existing role.

    When ``update_mask`` is given, only the listed fields of the request body
    are applied; the mask entry ``actions`` selects the body's ``permissions``.

    :raises BadRequest: the body fails validation or the mask names an
        unknown field.
    :raises NotFound: no role named ``role_name`` exists.
    :return: the serialized role.
    """
    security_manager = get_airflow_app().appbuilder.sm
    payload = request.json
    try:
        data = role_schema.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    role = security_manager.find_role(name=role_name)
    if not role:
        raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found")

    if update_mask:
        mask_fields = [entry.strip() for entry in update_mask]
        selected = {}
        for field in mask_fields:
            if field != "permissions" and field in data:
                selected[field] = data[field]
            elif field == "actions":
                # "actions" is the mask name for the body's "permissions" key.
                selected["permissions"] = data['permissions']
            else:
                raise BadRequest(detail=f"'{field}' in update_mask is unknown")
        data = selected

    if "permissions" in data:
        perms = []
        for item in data["permissions"]:
            if item:
                perms.append((item["action"]["name"], item["resource"]["name"]))
        _check_action_and_resource(security_manager, perms)
        security_manager.bulk_sync_roles([{"role": role_name, "perms": perms}])

    new_name = data.get("name")
    rename_requested = new_name is not None and new_name != role.name
    if rename_requested:
        security_manager.update_role(role_id=role.id, name=new_name)

    return role_schema.dump(role)
def get_dags(
    *,
    limit: int,
    offset: int = 0,
    tags: Optional[Collection[str]] = None,
    dag_id_pattern: Optional[str] = None,
    only_active: bool = True,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """List the DAGs the current user may read, one page at a time.

    Sub-DAGs are always excluded. Results can be narrowed to active DAGs, to
    DAG ids containing ``dag_id_pattern`` (case-insensitive) and to DAGs
    carrying at least one of ``tags``. ``total_entries`` counts all matches,
    not just the returned page.
    """
    criteria = [~DagModel.is_subdag]
    if only_active:
        criteria.append(DagModel.is_active)
    dags_query = session.query(DagModel).filter(*criteria)

    if dag_id_pattern:
        dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))

    readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
    dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags))

    if tags:
        tag_matches = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
        dags_query = dags_query.filter(or_(*tag_matches))

    total_entries = dags_query.count()
    page = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()

    return dags_collection_schema.dump(DAGCollection(dags=page, total_entries=total_entries))
def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None) -> APIResponse:
    """Get one page of users, sorted by ``order_by``.

    ``order_by`` is a field name, optionally prefixed with ``-`` for
    descending order. ``user_id`` is accepted as an alias for ``id``.

    :raises BadRequest: ``order_by`` names a field that may not be sorted on.
    :return: the serialized page of users plus the total user count.
    """
    appbuilder = get_airflow_app().appbuilder
    session = appbuilder.get_session
    total_entries = session.query(func.count(User.id)).scalar()

    direction = desc if order_by.startswith("-") else asc
    to_replace = {"user_id": "id"}
    # NOTE(review): each allowed name is passed to getattr(User, ...) below —
    # confirm every entry is a real attribute of the User model.
    allowed_filter_attrs = [
        'id',
        "first_name",
        "last_name",
        "user_name",
        "email",
        "is_active",
        "role",
    ]

    # Validate the bare field name. Checking the raw `order_by` rejected every
    # descending ("-field") sort and the "user_id" alias, although both are
    # explicitly handled here.
    sort_key = order_by.lstrip("-")
    if sort_key not in allowed_filter_attrs and sort_key not in to_replace:
        raise BadRequest(detail=f"Ordering with '{order_by}' is disallowed or "
                         f"the attribute does not exist on the model")
    order_param = to_replace.get(sort_key, sort_key)

    query = session.query(User)
    users = query.order_by(direction(getattr(User, order_param))).offset(offset).limit(limit).all()

    return user_collection_schema.dump(
        UserCollection(users=users, total_entries=total_entries))
def get_xcom_entries(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    limit: Optional[int],
    offset: Optional[int] = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get one page of XCom entries.

    ``~`` acts as a wildcard for ``dag_id``, ``task_id`` and ``dag_run_id``.
    With a wildcard ``dag_id`` the result is limited to DAGs the current user
    may read. ``total_entries`` counts all matches, not just the page.
    """
    xcoms = session.query(XCom)

    if dag_id == '~':
        readable_dag_ids = get_airflow_app().appbuilder.sm.get_readable_dag_ids(g.user)
        xcoms = xcoms.filter(XCom.dag_id.in_(readable_dag_ids))
    else:
        xcoms = xcoms.filter(XCom.dag_id == dag_id)

    # The DAG run is joined in both cases; it is needed for filtering and ordering.
    xcoms = xcoms.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))

    if task_id != '~':
        xcoms = xcoms.filter(XCom.task_id == task_id)
    if dag_run_id != '~':
        xcoms = xcoms.filter(DR.run_id == dag_run_id)

    xcoms = xcoms.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key)
    total_entries = xcoms.count()
    page = xcoms.offset(offset).limit(limit)

    return xcom_collection_schema.dump(XComCollection(xcom_entries=page.all(), total_entries=total_entries))
def get_role(*, role_name: str) -> APIResponse:
    """Get a single role by name.

    :raises NotFound: no role named ``role_name`` exists.
    """
    found = get_airflow_app().appbuilder.sm.find_role(name=role_name)
    if found:
        return role_schema.dump(found)
    raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found")
def check_authentication() -> None:
    """Check that the request has valid authorization information.

    Each backend in the app's ``api_auth`` list is tried in order; the first
    one that answers with status 200 authenticates the request.

    :raises Unauthenticated: no backend accepted the request. The headers of
        the last backend's response are attached when there is one.
    """
    response = None
    for auth in get_airflow_app().api_auth:
        response = auth.requires_authentication(Response)()
        if response.status_code == 200:
            return

    # since this handler only checks authentication, not authorization,
    # we should always return 401
    if response is None:
        # `api_auth` was empty, so there is no response to take headers from.
        # Previously this path failed with UnboundLocalError on `response`.
        # NOTE(review): assumes `headers` is optional on Unauthenticated — confirm.
        raise Unauthenticated()
    raise Unauthenticated(headers=response.headers)
def get_user(*, username: str) -> APIResponse:
    """Get a single user by username.

    :raises NotFound: no user with ``username`` exists.
    """
    found = get_airflow_app().appbuilder.sm.find_user(username=username)
    if found:
        return user_collection_item_schema.dump(found)
    raise NotFound(
        title="User not found",
        detail=f"The User with username `{username}` was not found",
    )
def delete_role(*, role_name: str) -> APIResponse:
    """Delete the role named ``role_name``.

    :raises NotFound: no role named ``role_name`` exists.
    :return: an empty body with the ``NO_CONTENT`` status.
    """
    security_manager = get_airflow_app().appbuilder.sm
    if not security_manager.find_role(name=role_name):
        raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found")

    security_manager.delete_role(role_name=role_name)
    return NoContent, HTTPStatus.NO_CONTENT
def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Set the state of a task instance and, optionally, of related instances.

    The JSON request body names the task, identifies the run by
    ``execution_date`` and/or ``dag_run_id``, and carries the new state plus
    the upstream/downstream/future/past flags. With ``dry_run`` set, the
    change is not committed.

    :raises BadRequest: the request body fails schema validation.
    :raises NotFound: the DAG, the task, or the addressed task instance does
        not exist.
    :return: the serialized references of the affected task instances.
    """
    payload = get_json_request_dict()
    try:
        data = set_task_instance_state_form.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(f"Dag ID {dag_id} not found")

    task_id = data['task_id']
    if not dag.task_dict.get(task_id):
        raise NotFound(f"Task ID {task_id} not found")

    execution_date = data.get('execution_date')
    run_id = data.get('dag_run_id')

    if execution_date:
        ti_by_date = (
            session.query(TI)
            .filter(TI.task_id == task_id, TI.dag_id == dag_id, TI.execution_date == execution_date)
            .one_or_none()
        )
        if ti_by_date is None:
            raise NotFound(
                detail=f"Task instance not found for task {task_id!r} on execution_date {execution_date}"
            )

    if run_id:
        # Primary-key lookup of the task instance with map_index -1.
        primary_key = {'task_id': task_id, 'dag_id': dag_id, 'run_id': run_id, 'map_index': -1}
        if not session.query(TI).get(primary_key):
            raise NotFound(
                detail=f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
            )

    should_commit = not data["dry_run"]
    affected = dag.set_task_instance_state(
        task_id=task_id,
        run_id=run_id,
        execution_date=execution_date,
        state=data["new_state"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        commit=should_commit,
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=affected)
    )
def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse:
    """Get one page of permission actions together with the total action count."""
    db = get_airflow_app().appbuilder.get_session
    total_entries = db.query(func.count(Action.id)).scalar()
    page = db.query(Action).offset(offset).limit(limit).all()
    collection = ActionCollection(actions=page, total_entries=total_entries)
    return action_collection_schema.dump(collection)
def delete_user(*, username: str) -> APIResponse:
    """Delete the user with the given username.

    :raises NotFound: no user with ``username`` exists.
    :return: an empty body with the ``NO_CONTENT`` status.
    """
    security_manager = get_airflow_app().appbuilder.sm
    target = security_manager.find_user(username=username)
    if target is None:
        raise NotFound(
            title="User not found",
            detail=f"The User with username `{username}` was not found",
        )

    # Clear foreign keys on this user first.
    target.roles = []
    db = security_manager.get_session
    db.delete(target)
    db.commit()

    return NoContent, HTTPStatus.NO_CONTENT
def auth_current_user() -> Optional[User]:
    """Authenticate from the Authorization header and log the user in.

    LDAP authentication is attempted first when the security manager's
    ``auth_type`` is ``AUTH_LDAP``; database authentication is the fallback.

    :return: the authenticated user, or ``None`` when credentials are missing
        or rejected.
    """
    credentials = request.authorization
    if credentials is None:
        return None
    if not (credentials.username and credentials.password):
        return None

    security_manager = get_airflow_app().appbuilder.sm
    if security_manager.auth_type == AUTH_LDAP:
        user = security_manager.auth_user_ldap(credentials.username, credentials.password)
    else:
        user = None
    if user is None:
        user = security_manager.auth_user_db(credentials.username, credentials.password)

    if user is None:
        return None
    login_user(user, remember=False)
    return user
def requires_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Callable[[T], T]:
    """Build a decorator that enforces ``permissions`` on the wrapped callable.

    The permissions are synced with the security manager once, when the
    factory is called. On every call of the wrapped function the request is
    authenticated and then authorized against ``permissions`` and the
    ``dag_id`` keyword argument, if present.

    :raises PermissionDenied: raised by the wrapper when authorization fails.
    """
    appbuilder = get_airflow_app().appbuilder
    appbuilder.sm.sync_resource_permissions(permissions)

    def requires_access_decorator(func: T):
        @wraps(func)
        def decorated(*args, **kwargs):
            check_authentication()
            authorized = appbuilder.sm.check_authorization(permissions, kwargs.get('dag_id'))
            if not authorized:
                raise PermissionDenied()
            return func(*args, **kwargs)

        return cast(T, decorated)

    return requires_access_decorator
def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear all task instances of one DAG run.

    With ``dry_run`` set in the request body nothing is cleared and the task
    instances that would be affected are returned; otherwise the run is
    cleared and the refreshed DAG run is returned.

    :raises NotFound: the DAG run does not exist.
    :raises BadRequest: the request body fails schema validation.
    """
    dag_run: Optional[DagRun] = (
        session.query(DagRun)
        .filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
        .one_or_none()
    )
    if dag_run is None:
        raise NotFound(f'Dag Run id {dag_run_id} not found in dag {dag_id}')

    try:
        post_body = clear_dagrun_form_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    dry_run = post_body.get('dry_run', False)
    dag = get_airflow_app().dag_bag.get_dag(dag_id)

    # Both modes select exactly the run's own logical date, with the same options.
    clear_options = dict(
        start_date=dag_run.logical_date,
        end_date=dag_run.logical_date,
        task_ids=None,
        include_subdags=True,
        include_parentdag=True,
        only_failed=False,
    )

    if not dry_run:
        dag.clear(**clear_options)
        dag_run.refresh_from_db()
        return dagrun_schema.dump(dag_run)

    would_clear = dag.clear(dry_run=True, **clear_options)
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=would_clear)
    )
def get_dag_runs(
    *,
    dag_id: str,
    start_date_gte: Optional[str] = None,
    start_date_lte: Optional[str] = None,
    execution_date_gte: Optional[str] = None,
    execution_date_lte: Optional[str] = None,
    end_date_gte: Optional[str] = None,
    end_date_lte: Optional[str] = None,
    state: Optional[List[str]] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    order_by: str = "id",
    session: Session = NEW_SESSION,
):
    """Get one page of DAG runs for a DAG, or for all readable DAGs.

    Passing ``~`` as ``dag_id`` selects the runs of every DAG the current user
    may read. Date bounds, paging and ordering are applied by
    ``_fetch_dag_runs``.
    """
    runs_query = session.query(DagRun)

    #  This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
    if dag_id != "~":
        runs_query = runs_query.filter(DagRun.dag_id == dag_id)
    else:
        appbuilder = get_airflow_app().appbuilder
        readable_ids = appbuilder.sm.get_readable_dag_ids(g.user)
        runs_query = runs_query.filter(DagRun.dag_id.in_(readable_ids))

    if state:
        runs_query = runs_query.filter(DagRun.state.in_(state))

    runs, total_entries = _fetch_dag_runs(
        runs_query,
        end_date_gte=end_date_gte,
        end_date_lte=end_date_lte,
        execution_date_gte=execution_date_gte,
        execution_date_lte=execution_date_lte,
        start_date_gte=start_date_gte,
        start_date_lte=start_date_lte,
        limit=limit,
        offset=offset,
        order_by=order_by,
    )
    collection = DAGRunCollection(dag_runs=runs, total_entries=total_entries)
    return dagrun_collection_schema.dump(collection)
def post_user() -> APIResponse:
    """Create a new user from the JSON request body.

    When the body lists no roles, the security manager's
    ``auth_user_registration_role`` is assigned.

    :raises BadRequest: the body fails validation or names unknown roles.
    :raises AlreadyExists: the username or the email is already in use.
    :raises Unknown: the security manager did not return a user.
    :return: the serialized new user.
    """
    try:
        data = user_schema.load(request.json)
    except ValidationError as e:
        raise BadRequest(detail=str(e.messages))

    security_manager = get_airflow_app().appbuilder.sm
    username = data["username"]
    email = data["email"]

    if security_manager.find_user(username=username):
        detail = f"Username `{username}` already exists. Use PATCH to update."
        raise AlreadyExists(detail=detail)
    if security_manager.find_user(email=email):
        detail = f"The email `{email}` is already taken."
        raise AlreadyExists(detail=detail)

    # Resolve every requested role first so all unknown names are reported together.
    roles_to_add = []
    missing_role_names = []
    for role_data in data.pop("roles", ()):
        role_name = role_data["name"]
        role = security_manager.find_role(role_name)
        if role is None:
            missing_role_names.append(role_name)
        else:
            roles_to_add.append(role)
    if missing_role_names:
        detail = f"Unknown roles: {', '.join(repr(n) for n in missing_role_names)}"
        raise BadRequest(detail=detail)

    if not roles_to_add:
        # No roles provided, use the F.A.B's default registered user role.
        roles_to_add.append(security_manager.find_role(security_manager.auth_user_registration_role))

    user = security_manager.add_user(role=roles_to_add, **data)
    if not user:
        detail = f"Failed to add user `{username}`."
        # Raise the error like every other failure path here; it used to be
        # *returned*, which handed the error object back as the response body.
        raise Unknown(detail=detail)

    return user_schema.dump(user)
def post_role() -> APIResponse:
    """Create a new role with the permissions given in the JSON request body.

    :raises BadRequest: the request body fails schema validation.
    :raises AlreadyExists: a role with the requested name already exists.
    :return: the serialized new role.
    """
    security_manager = get_airflow_app().appbuilder.sm
    body = request.json
    try:
        data = role_schema.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    role = security_manager.find_role(name=data['name'])
    if role:
        detail = f"Role with name {role.name!r} already exists; please update with the PATCH endpoint"
        raise AlreadyExists(detail=detail)

    perms = [(item['action']['name'], item['resource']['name']) for item in data['permissions'] if item]
    _check_action_and_resource(security_manager, perms)
    security_manager.bulk_sync_roles([{"role": data["name"], "perms": perms}])

    # Look the role up again: `role` above is the falsy result of the
    # pre-creation lookup, so dumping it would not serialize the new role.
    role = security_manager.find_role(name=data['name'])
    return role_schema.dump(role)
def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None):
    """Set ``is_paused`` on one page of the DAGs the current user may edit.

    DAGs are selected by ``dag_id_pattern`` (``~`` matches every DAG id), by
    ``tags`` and by ``only_active``; only the page given by ``offset`` and
    ``limit`` is updated. ``update_mask``, when given, must be exactly
    ``['is_paused']``.

    :raises BadRequest: the body fails validation or the mask names another field.
    :return: the serialized page of selected DAGs and the total match count.
    """
    try:
        patch_body = dag_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    if update_mask:
        if update_mask != ['is_paused']:
            raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
        masked_field = update_mask[0]
        patch_body = {masked_field: patch_body[masked_field]}

    criteria = [~DagModel.is_subdag]
    if only_active:
        criteria.append(DagModel.is_active)
    dags_query = session.query(DagModel).filter(*criteria)

    if dag_id_pattern == '~':
        dag_id_pattern = '%'
    # NOTE(review): with dag_id_pattern left at its default None, the pattern
    # below is the literal '%None%' — confirm callers always supply a pattern.
    dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))

    editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user)
    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))

    if tags:
        tag_matches = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
        dags_query = dags_query.filter(or_(*tag_matches))

    total_entries = dags_query.count()
    dags = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()

    dags_to_update = {dag.dag_id for dag in dags}
    page_rows = session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update))
    page_rows.update({DagModel.is_paused: patch_body['is_paused']}, synchronize_session='fetch')
    session.flush()

    return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))
def get_extra_links(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get the extra links of one task instance.

    :raises NotFound: the DAG, the task, or the task instance for
        ``dag_run_id`` does not exist.
    :return: a mapping of link name to URL, built in sorted order; a link
        with no URL maps to ``None``.
    """
    from airflow.models.taskinstance import TaskInstance

    dagbag: DagBag = get_airflow_app().dag_bag
    dag: DAG = dagbag.get_dag(dag_id)
    if not dag:
        raise NotFound("DAG not found", detail=f'DAG with ID = "{dag_id}" not found')

    try:
        task = dag.get_task(task_id)
    except TaskNotFound:
        raise NotFound("Task not found", detail=f'Task with ID = "{task_id}" not found')

    ti_query = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag_id,
        TaskInstance.run_id == dag_run_id,
        TaskInstance.task_id == task_id,
    )
    ti = ti_query.one_or_none()
    if not ti:
        raise NotFound("DAG Run not found", detail=f'DAG Run with ID = "{dag_run_id}" not found')

    resolved_pairs = sorted((name, task.get_extra_links(ti, name)) for name in task.extra_links)
    links = {}
    for name, url in resolved_pairs:
        # Falsy URLs are reported as None.
        links[name] = url or None
    return links
def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = None) -> APIResponse:
    """Get one page of roles, sorted by ``order_by``.

    ``order_by`` is ``name`` or ``role_id``, optionally prefixed with ``-``
    for descending order. ``role_id`` sorts on the model's ``id`` column.

    :raises BadRequest: ``order_by`` names a field that may not be sorted on.
    :return: the serialized page of roles plus the total role count.
    """
    appbuilder = get_airflow_app().appbuilder
    session = appbuilder.get_session
    total_entries = session.query(func.count(Role.id)).scalar()

    direction = desc if order_by.startswith("-") else asc
    to_replace = {"role_id": "id"}
    allowed_filter_attrs = ["role_id", "name"]

    # Validate the bare field name. Checking the raw `order_by` rejected every
    # descending ("-field") sort although a direction is computed for it above.
    sort_key = order_by.lstrip("-")
    if sort_key not in allowed_filter_attrs:
        raise BadRequest(detail=f"Ordering with '{order_by}' is disallowed or "
                         f"the attribute does not exist on the model")
    order_param = to_replace.get(sort_key, sort_key)

    query = session.query(Role)
    roles = query.order_by(direction(getattr(Role, order_param))).offset(offset).limit(limit).all()

    return role_collection_schema.dump(
        RoleCollection(roles=roles, total_entries=total_entries))
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
    """Get one page of DAG runs selected by the filters in the JSON request body.

    Only DAGs the current user may read are considered; requested ``dag_ids``
    are intersected with that set. Date bounds, paging and ordering are
    applied by ``_fetch_dag_runs``.

    :raises BadRequest: the request body fails schema validation.
    """
    payload = get_json_request_dict()
    try:
        data = dagruns_batch_form_schema.load(payload)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    readable_dag_ids = get_airflow_app().appbuilder.sm.get_readable_dag_ids(g.user)

    requested_ids = data.get("dag_ids")
    if requested_ids:
        visible_ids = set(requested_ids).intersection(readable_dag_ids)
    else:
        visible_ids = readable_dag_ids
    runs_query = session.query(DagRun).filter(DagRun.dag_id.in_(visible_ids))

    states = data.get("states")
    if states:
        runs_query = runs_query.filter(DagRun.state.in_(states))

    dag_runs, total_entries = _fetch_dag_runs(
        runs_query,
        end_date_gte=data["end_date_gte"],
        end_date_lte=data["end_date_lte"],
        execution_date_gte=data["execution_date_gte"],
        execution_date_lte=data["execution_date_lte"],
        start_date_gte=data["start_date_gte"],
        start_date_lte=data["start_date_lte"],
        limit=data["page_limit"],
        offset=data["page_offset"],
        order_by=data.get("order_by", "id"),
    )

    collection = DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)
    return dagrun_collection_schema.dump(collection)
def get_mapped_task_instances(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    execution_date_gte: Optional[str] = None,
    execution_date_lte: Optional[str] = None,
    start_date_gte: Optional[str] = None,
    start_date_lte: Optional[str] = None,
    end_date_gte: Optional[str] = None,
    end_date_lte: Optional[str] = None,
    duration_gte: Optional[float] = None,
    duration_lte: Optional[float] = None,
    state: Optional[List[str]] = None,
    pool: Optional[List[str]] = None,
    queue: Optional[List[str]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    order_by: Optional[str] = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get one page of the expanded task instances of a mapped task in a DAG run.

    Only instances with ``map_index >= 0`` are considered. Supported
    ``order_by`` values are ``state``, ``-state``, ``map_index`` and
    ``-map_index``; the default is ascending ``map_index``.

    :raises NotFound: no instances exist and the DAG or task is unknown, or
        the task is not mapped.
    :raises BadRequest: ``order_by`` is not one of the supported values.
    :return: the serialized page of task instances and the total match count.
    """
    # Because state can be 'none'
    states = _convert_state(state)

    base_query = (
        session.query(TI)
        .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index >= 0)
        .join(TI.dag_run)
    )

    # 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404
    if base_query.with_entities(func.count('*')).scalar() == 0:
        dag = get_airflow_app().dag_bag.get_dag(dag_id)
        if not dag:
            error_message = f"DAG {dag_id} not found"
            raise NotFound(error_message)
        # Use task_dict.get() so an unknown id yields None. dag.get_task()
        # signals a missing task by raising TaskNotFound (see get_extra_links),
        # which made the check below unreachable.
        task = dag.task_dict.get(task_id)
        if not task:
            error_message = f"Task id {task_id} not found"
            raise NotFound(error_message)
        if not task.is_mapped:
            error_message = f"Task id {task_id} is not mapped"
            raise NotFound(error_message)

    # Other search criteria
    query = _apply_range_filter(
        base_query,
        key=DR.execution_date,
        value_range=(execution_date_gte, execution_date_lte),
    )
    query = _apply_range_filter(query, key=TI.start_date, value_range=(start_date_gte, start_date_lte))
    query = _apply_range_filter(query, key=TI.end_date, value_range=(end_date_gte, end_date_lte))
    query = _apply_range_filter(query, key=TI.duration, value_range=(duration_gte, duration_lte))
    query = _apply_array_filter(query, key=TI.state, values=states)
    query = _apply_array_filter(query, key=TI.pool, values=pool)
    query = _apply_array_filter(query, key=TI.queue, values=queue)

    # Count elements before joining extra columns
    total_entries = query.with_entities(func.count('*')).scalar()

    # Add SLA miss
    query = (
        query.join(
            SlaMiss,
            and_(
                SlaMiss.dag_id == TI.dag_id,
                SlaMiss.task_id == TI.task_id,
                SlaMiss.execution_date == DR.execution_date,
            ),
            isouter=True,
        )
        .add_entity(SlaMiss)
        .options(joinedload(TI.rendered_task_instance_fields))
    )

    if not order_by or order_by == 'map_index':
        # Explicit 'map_index' is the same as the default ordering; it used to
        # be rejected as unsupported.
        query = query.order_by(TI.map_index.asc())
    elif order_by == 'state':
        query = query.order_by(TI.state.asc(), TI.map_index.asc())
    elif order_by == '-state':
        query = query.order_by(TI.state.desc(), TI.map_index.asc())
    elif order_by == '-map_index':
        query = query.order_by(TI.map_index.desc())
    else:
        raise BadRequest(detail=f"Ordering with '{order_by}' is not supported")

    task_instances = query.offset(offset).limit(limit).all()
    return task_instance_collection_schema.dump(
        TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries)
    )
def get_dag_details(*, dag_id: str) -> APIResponse:
    """Get the details of one DAG from the DAG bag.

    :raises NotFound: no DAG with ``dag_id`` is in the DAG bag.
    """
    dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
    if dag:
        return dag_detail_schema.dump(dag)
    raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found")
def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse:
    """Update an existing user from the JSON request body.

    When ``update_mask`` is given, only the listed fields are applied. A new
    password is stored hashed, and ``roles`` replaces the user's current roles.

    :raises BadRequest: the body fails validation, the mask names a field
        missing from the body, or a requested role is unknown.
    :raises NotFound: no user with ``username`` exists.
    :raises AlreadyExists: the new username or email belongs to another user.
    :return: the serialized updated user.
    """
    try:
        data = user_schema.load(request.json)
    except ValidationError as e:
        raise BadRequest(detail=str(e.messages))

    security_manager = get_airflow_app().appbuilder.sm
    user = security_manager.find_user(username=username)
    if user is None:
        raise NotFound(
            title="User not found",
            detail=f"The User with username `{username}` was not found",
        )

    # Check unique username
    new_username = data.get('username')
    if new_username and new_username != username and security_manager.find_user(username=new_username):
        raise AlreadyExists(detail=f"The username `{new_username}` already exists")

    # Check unique email
    email = data.get('email')
    if email and email != user.email and security_manager.find_user(email=email):
        raise AlreadyExists(detail=f"The email `{email}` already exists")

    # Get fields to update.
    if update_mask is not None:
        selected = {}
        missing_mask_names = []
        for field in (raw.strip() for raw in update_mask):
            if field in data:
                selected[field] = data[field]
            else:
                missing_mask_names.append(field)
        if missing_mask_names:
            detail = f"Unknown update masks: {', '.join(repr(n) for n in missing_mask_names)}"
            raise BadRequest(detail=detail)
        data = selected

    roles_to_update: Optional[List[Role]] = None  # None: don't change existing value.
    if "roles" in data:
        roles_to_update = []
        missing_role_names = []
        for role_name in (entry["name"] for entry in data.pop("roles", ())):
            role = security_manager.find_role(role_name)
            if role is None:
                missing_role_names.append(role_name)
            else:
                roles_to_update.append(role)
        if missing_role_names:
            detail = f"Unknown roles: {', '.join(repr(n) for n in missing_role_names)}"
            raise BadRequest(detail=detail)

    if "password" in data:
        user.password = generate_password_hash(data.pop("password"))
    if roles_to_update is not None:
        user.roles = roles_to_update
    # Whatever is left maps directly onto attributes of the user.
    for attr_name, attr_value in data.items():
        setattr(user, attr_name, attr_value)

    security_manager.update_user(user)
    return user_schema.dump(user)
def get_log(
    *,
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    task_try_number: int,
    full_content: bool = False,
    token: Optional[str] = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Get the logs of one try of a task instance.

    ``token`` is a signed continuation token from a previous response and
    carries the reader metadata. Depending on the request's ``Accept``
    header, the result is a JSON chunk with a new continuation token
    (also the default) or a ``text/plain`` stream.

    :raises BadRequest: the token signature is invalid or the log handler
        cannot read logs.
    :raises NotFound: the task instance does not exist.
    """
    key = get_airflow_app().config["SECRET_KEY"]

    if token:
        try:
            metadata = URLSafeSerializer(key).loads(token)
        except BadSignature:
            raise BadRequest("Bad Signature. Please use only the tokens provided by the API.")
    else:
        metadata = {}

    # A token that already asked for the full log keeps doing so.
    if metadata.get('download_logs'):
        full_content = True
    metadata['download_logs'] = bool(full_content)

    task_log_reader = TaskLogReader()
    if not task_log_reader.supports_read:
        raise BadRequest("Task log handler does not support read logs.")

    ti_query = session.query(TaskInstance).filter(
        TaskInstance.task_id == task_id,
        TaskInstance.dag_id == dag_id,
        TaskInstance.run_id == dag_run_id,
    )
    ti = ti_query.join(TaskInstance.dag_run).one_or_none()
    if ti is None:
        metadata['end_of_log'] = True
        raise NotFound(title="TaskInstance not found")

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if dag:
        # Attach the task when the DAG still defines it; a missing task is ignored.
        try:
            ti.task = dag.get_task(ti.task_id)
        except TaskNotFound:
            pass

    return_type = request.accept_mimetypes.best_match(['text/plain', 'application/json'])

    # return_type would be either the above two or None
    logs: Any
    if return_type in ('application/json', None):  # default
        logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata)
        if task_try_number is not None:
            logs = logs[0]
        # we must have token here, so we can safely ignore it
        token = URLSafeSerializer(key).dumps(metadata)  # type: ignore[assignment]
        return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))

    # text/plain. Stream
    logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)
    return Response(logs, headers={"Content-Type": return_type})