def _check_action_and_resource(sm: AirflowSecurityManager, perms: List[Tuple[str, str]]) -> None:
    """
    Check that every (action, resource) pair exists in the security manager.

    Intended for REST API use: it raises 400 (BadRequest) when an action or
    resource is unknown, instead of letting the failure surface later.

    :param sm: security manager used to look up actions and resources
    :param perms: iterable of ``(action_name, resource_name)`` pairs to validate
    :raises BadRequest: if any action or resource does not exist
    """
    for action, resource in perms:
        if not sm.get_action(action):
            raise BadRequest(detail=f"The specified action: {action!r} was not found")
        if not sm.get_resource(resource):
            raise BadRequest(detail=f"The specified resource: {resource!r} was not found")
def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content=False, token=None):
    """Get logs for specific task instance.

    The ``token`` is an opaque continuation token produced by a previous call;
    it carries reader metadata (signed with the app SECRET_KEY) so log reading
    can resume where it left off.

    :raises BadRequest: on a tampered token, unsupported log handler,
        or missing task instance
    :raises NotFound: if the DAG run does not exist
    """
    key = current_app.config["SECRET_KEY"]
    if not token:
        metadata = {}
    else:
        try:
            # Tokens are signed; a bad signature means the client altered it.
            metadata = URLSafeSerializer(key).loads(token)
        except BadSignature:
            raise BadRequest("Bad Signature. Please use only the tokens provided by the API.")

    if metadata.get('download_logs') and metadata['download_logs']:
        full_content = True

    # Persist the download preference in the continuation metadata.
    if full_content:
        metadata['download_logs'] = True
    else:
        metadata['download_logs'] = False

    task_log_reader = TaskLogReader()
    if not task_log_reader.supports_read:
        raise BadRequest("Task log handler does not support read logs.")

    query = session.query(DagRun).filter(DagRun.dag_id == dag_id)
    dag_run = query.filter(DagRun.run_id == dag_run_id).first()
    if not dag_run:
        raise NotFound("DAG Run not found")

    ti = dag_run.get_task_instance(task_id, session)
    if ti is None:
        metadata['end_of_log'] = True
        raise BadRequest(detail="Task instance did not exist in the DB")

    dag = current_app.dag_bag.get_dag(dag_id)
    if dag:
        # Attach the task so the log reader can resolve handler config.
        ti.task = dag.get_task(ti.task_id)

    return_type = request.accept_mimetypes.best_match(['text/plain', 'application/json'])

    # return_type would be either the above two or None
    if return_type == 'application/json' or return_type is None:  # default
        logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata)
        # A specific try number yields a single chunk; otherwise all chunks.
        logs = logs[0] if task_try_number is not None else logs
        token = URLSafeSerializer(key).dumps(metadata)
        return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))
    # text/plain. Stream
    logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)
    return Response(logs, headers={"Content-Type": return_type})
def patch_pool(pool_name, session, update_mask=None): """ Update a pool """ # Only slots can be modified in 'default_pool' try: if pool_name == Pool.DEFAULT_POOL_NAME and request.json[ "name"] != Pool.DEFAULT_POOL_NAME: if update_mask and len( update_mask) == 1 and update_mask[0].strip() == "slots": pass else: raise BadRequest( detail="Default Pool's name can't be modified") except KeyError: pass pool = session.query(Pool).filter(Pool.pool == pool_name).first() if not pool: raise NotFound(detail=f"Pool with name:'{pool_name}' not found") try: patch_body = pool_schema.load(request.json) except ValidationError as err: raise BadRequest(detail=str(err.messages)) if update_mask: update_mask = [i.strip() for i in update_mask] _patch_body = {} try: update_mask = [ pool_schema.declared_fields[field].attribute if pool_schema.declared_fields[field].attribute else field for field in update_mask ] except KeyError as err: raise BadRequest( detail=f"Invalid field: {err.args[0]} in update mask") _patch_body = {field: patch_body[field] for field in update_mask} patch_body = _patch_body else: for field in ["name", "slots"]: if field not in request.json.keys(): raise BadRequest(detail=f"'{field}' is a required property") for key, value in patch_body.items(): setattr(pool, key, value) session.commit() return pool_schema.dump(pool)
def get_dag_runs_batch(session):
    """Get list of DAG Runs.

    Reads filter criteria from the JSON body and restricts results to DAGs
    the current user is allowed to read.

    :raises BadRequest: if the request body fails schema validation
    """
    body = request.get_json()
    try:
        data = dagruns_batch_form_schema.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    appbuilder = current_app.appbuilder
    readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
    query = session.query(DagRun)
    if data.get("dag_ids"):
        # Intersect the requested DAG ids with what the user may read.
        dag_ids = set(data["dag_ids"]) & set(readable_dag_ids)
        query = query.filter(DagRun.dag_id.in_(dag_ids))
    else:
        query = query.filter(DagRun.dag_id.in_(readable_dag_ids))

    dag_runs, total_entries = _fetch_dag_runs(
        query,
        data["end_date_gte"],
        data["end_date_lte"],
        data["execution_date_gte"],
        data["execution_date_lte"],
        data["start_date_gte"],
        data["start_date_lte"],
        data["page_limit"],
        data["page_offset"],
        order_by=data.get('order_by', "id"),
    )

    return dagrun_collection_schema.dump(
        DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))
def post_clear_task_instances(dag_id: str, session=None):
    """Clear task instances.

    Computes the affected task instances via ``dag.clear(dry_run=True)`` and,
    unless the caller asked for a dry run, actually clears them.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG does not exist
    """
    body = request.get_json()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    task_instances = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **data)
    if not dry_run:
        clear_task_instances(
            task_instances, session, dag=dag,
            dag_run_state=State.RUNNING if reset_dag_runs else False)

    # Join DagRun to expose run_id alongside each cleared task instance.
    task_instances = task_instances.join(
        DR, and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)).add_column(DR.run_id)
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all()))
def get_task_instances_batch(session=None):
    """
    Get list of task instances.

    Applies the array/range filters from the JSON body, counts the matches,
    then outer-joins SlaMiss so each result row carries its SLA-miss record
    (or None).

    :raises BadRequest: if the request body fails schema validation
    """
    body = request.get_json()
    try:
        data = task_instance_batch_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    base_query = session.query(TI)

    base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"])
    base_query = _apply_range_filter(
        base_query,
        key=TI.execution_date,
        value_range=(data["execution_date_gte"], data["execution_date_lte"]),
    )
    base_query = _apply_range_filter(
        base_query,
        key=TI.start_date,
        value_range=(data["start_date_gte"], data["start_date_lte"]),
    )
    base_query = _apply_range_filter(
        base_query, key=TI.end_date, value_range=(data["end_date_gte"], data["end_date_lte"]))
    base_query = _apply_range_filter(
        base_query, key=TI.duration, value_range=(data["duration_gte"], data["duration_lte"]))
    base_query = _apply_array_filter(base_query, key=TI.state, values=data["state"])
    base_query = _apply_array_filter(base_query, key=TI.pool, values=data["pool"])
    base_query = _apply_array_filter(base_query, key=TI.queue, values=data["queue"])

    # Count elements before joining extra columns
    total_entries = base_query.with_entities(func.count('*')).scalar()
    # Add join
    base_query = base_query.join(
        SlaMiss,
        and_(
            SlaMiss.dag_id == TI.dag_id,
            SlaMiss.task_id == TI.task_id,
            SlaMiss.execution_date == TI.execution_date,
        ),
        isouter=True,
    )
    ti_query = base_query.add_entity(SlaMiss)
    task_instances = ti_query.all()

    return task_instance_collection_schema.dump(
        TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries))
def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None) -> APIResponse:
    """Get users.

    :param limit: maximum number of users to return
    :param order_by: attribute to sort on; a leading ``-`` sorts descending,
        and ``user_id`` is accepted as an alias for ``id``
    :param offset: number of records to skip
    :raises BadRequest: if the ordering attribute is not allowed
    """
    appbuilder = get_airflow_app().appbuilder
    session = appbuilder.get_session
    total_entries = session.query(func.count(User.id)).scalar()
    direction = desc if order_by.startswith("-") else asc
    to_replace = {"user_id": "id"}
    order_param = order_by.strip("-")
    order_param = to_replace.get(order_param, order_param)
    allowed_filter_attrs = [
        'id',
        "first_name",
        "last_name",
        "user_name",
        "email",
        "is_active",
        "role",
    ]
    # BUGFIX: validate the normalized attribute (direction prefix stripped,
    # aliases applied). The previous check compared the raw ``order_by``
    # value, so every descending sort (e.g. "-id") and the "user_id" alias
    # were rejected with a 400 even though they are supported above.
    if order_param not in allowed_filter_attrs:
        raise BadRequest(detail=f"Ordering with '{order_by}' is disallowed or "
                         f"the attribute does not exist on the model")

    query = session.query(User)
    users = query.order_by(direction(getattr(
        User, order_param))).offset(offset).limit(limit).all()

    return user_collection_schema.dump(
        UserCollection(users=users, total_entries=total_entries))
def get_dag_runs_batch(session):
    """
    Get list of DAG Runs.

    Filter criteria come from the JSON body; paging and date-range filtering
    are delegated to ``_fetch_dag_runs``.

    :raises BadRequest: if the request body fails schema validation
    """
    body = request.get_json()
    try:
        data = dagruns_batch_form_schema.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    query = session.query(DagRun)

    if data["dag_ids"]:
        # Restrict to the explicitly requested DAG ids.
        query = query.filter(DagRun.dag_id.in_(data["dag_ids"]))

    dag_runs, total_entries = _fetch_dag_runs(
        query,
        session,
        data["end_date_gte"],
        data["end_date_lte"],
        data["execution_date_gte"],
        data["execution_date_lte"],
        data["start_date_gte"],
        data["start_date_lte"],
        data["page_limit"],
        data["page_offset"],
    )

    return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances.

    Computes affected task instances with a forced ``dry_run=True`` clear,
    then performs the real clearing only when the caller did not request a
    dry run.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG does not exist
    """
    body = request.get_json()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    task_instances = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all()))
def post_set_task_instances_state(dag_id, session):
    """Set a state of task instances.

    Delegates to ``dag.set_task_instance_state``, optionally fanning out to
    upstream/downstream and past/future instances; ``dry_run`` suppresses the
    commit.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG or task does not exist
    """
    body = request.get_json()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    error_message = f"Dag ID {dag_id} not found"
    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(error_message)

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)

    if not task:
        error_message = f"Task ID {task_id} not found"
        raise NotFound(error_message)

    tis = dag.set_task_instance_state(
        task_id=task_id,
        execution_date=data["execution_date"],
        state=data["new_state"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        commit=not data["dry_run"],
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis))
def post_dag_run(dag_id, session):
    """Trigger a DAG.

    Creates a manual DAG run unless a run with the same run_id or
    execution_date already exists.

    :raises NotFound: if the DAG does not exist
    :raises BadRequest: if the request body fails schema validation
    :raises AlreadyExists: on a duplicate run_id or execution_date
    """
    if not session.query(DagModel).filter(DagModel.dag_id == dag_id).first():
        raise NotFound(title="DAG not found", detail=f"DAG with dag_id: '{dag_id}' not found")
    try:
        post_body = dagrun_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    # Either field clashing with an existing run makes this a duplicate.
    dagrun_instance = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id,
        or_(DagRun.run_id == post_body["run_id"],
            DagRun.execution_date == post_body["execution_date"]),
    ).first())
    if not dagrun_instance:
        dag_run = DagRun(dag_id=dag_id, run_type=DagRunType.MANUAL, **post_body)
        session.add(dag_run)
        session.commit()
        return dagrun_schema.dump(dag_run)

    # Report which of the two unique fields collided.
    if dagrun_instance.execution_date == post_body["execution_date"]:
        raise AlreadyExists(
            detail=f"DAGRun with DAG ID: '{dag_id}' and "
            f"DAGRun ExecutionDate: '{post_body['execution_date']}' already exists")

    raise AlreadyExists(
        detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{post_body['run_id']}' already exists")
def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Set a state of a dag run.

    Dispatches to the success/queued/failed state-setting helper based on the
    requested state, then re-reads the run to return its updated form.

    :raises NotFound: if the DAG run does not exist
    :raises BadRequest: if the request body fails schema validation
    """
    dag_run: Optional[DagRun] = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none())
    if dag_run is None:
        error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
        raise NotFound(error_message)
    try:
        post_body = set_dagrun_state_form_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    state = post_body['state']
    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if state == DagRunState.SUCCESS:
        set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True)
    elif state == DagRunState.QUEUED:
        set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True)
    else:
        # Schema validation limits the remaining possibility to "failed".
        set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True)
    # Re-fetch so the serialized response reflects the committed state.
    dag_run = session.query(DagRun).get(dag_run.id)
    return dagrun_schema.dump(dag_run)
def post_clear_task_instances(dag_id: str, session=None):
    """
    Clear task instances.

    Collects the matching task instances, clears them unless this is a dry
    run, and optionally resets the affected DAG runs to RUNNING.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG does not exist
    """
    body = request.get_json()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        error_message = "Dag id {} not found".format(dag_id)
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    task_instances = dag.clear(get_tis=True, **data)
    if not data["dry_run"]:
        clear_task_instances(
            task_instances,
            session,
            dag=dag,
            activate_dag_runs=False,  # We will set DagRun state later.
        )
        if reset_dag_runs:
            dag.set_dag_runs_state(
                session=session,
                start_date=data["start_date"],
                end_date=data["end_date"],
                state=State.RUNNING,
            )
    # Join DagRun to expose run_id alongside each cleared task instance.
    task_instances = task_instances.join(
        DR, and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)).add_column(DR.run_id)
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all()))
def delete_pool(pool_name: str, session):
    """Delete the pool identified by ``pool_name``.

    The default pool is protected and can never be removed.

    :raises BadRequest: when attempting to delete the default pool
    :raises NotFound: when no pool with that name exists
    """
    # The default pool must always exist; refuse to delete it.
    if pool_name == "default_pool":
        raise BadRequest(detail="Default Pool can't be deleted")
    deleted_rows = session.query(Pool).filter(Pool.pool == pool_name).delete()
    if deleted_rows == 0:
        raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
    return Response(status=204)
def post_user():
    """Create a new user.

    Validates the payload, rejects duplicate usernames/emails and unknown
    roles, then creates the user through the security manager. When no roles
    are supplied the F.A.B. default registration role is used.

    :raises BadRequest: on schema violations or unknown roles
    :raises AlreadyExists: on duplicate username or email
    :raises Unknown: if the security manager fails to create the user
    """
    try:
        data = user_schema.load(request.json)
    except ValidationError as e:
        raise BadRequest(detail=str(e.messages))

    security_manager = current_app.appbuilder.sm
    username = data["username"]
    email = data["email"]

    if security_manager.find_user(username=username):
        detail = f"Username `{username}` already exists. Use PATCH to update."
        raise AlreadyExists(detail=detail)
    if security_manager.find_user(email=email):
        detail = f"The email `{email}` is already taken."
        raise AlreadyExists(detail=detail)

    roles_to_add = []
    missing_role_names = []
    for role_data in data.pop("roles", ()):
        role_name = role_data["name"]
        role = security_manager.find_role(role_name)
        if role is None:
            missing_role_names.append(role_name)
        else:
            roles_to_add.append(role)
    if missing_role_names:
        detail = f"Unknown roles: {', '.join(repr(n) for n in missing_role_names)}"
        raise BadRequest(detail=detail)

    if roles_to_add:
        default_role = roles_to_add.pop()
    else:
        # No roles provided, use the F.A.B's default registered user role.
        default_role = security_manager.find_role(
            security_manager.auth_user_registration_role)

    user = security_manager.add_user(role=default_role, **data)
    if not user:
        detail = f"Failed to add user `{username}`."
        # BUGFIX: this previously *returned* the exception instance instead of
        # raising it, so the client received a serialized exception object with
        # a success status instead of an error response.
        raise Unknown(detail=detail)

    if roles_to_add:
        user.roles.extend(roles_to_add)
        security_manager.update_user(user)

    return user_schema.dump(user)
def post_dag_run(dag_id, session):
    """Trigger a DAG.

    Creates a QUEUED manual run via ``dag.create_dagrun`` unless a run with
    the same run_id or logical date already exists.

    :raises NotFound: if the DAG does not exist
    :raises BadRequest: on schema violations or a rejected run creation
    :raises AlreadyExists: on a duplicate run_id or logical date
    """
    if not session.query(DagModel).filter(DagModel.dag_id == dag_id).first():
        raise NotFound(title="DAG not found", detail=f"DAG with dag_id: '{dag_id}' not found")

    try:
        post_body = dagrun_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err))

    logical_date = pendulum.instance(post_body["execution_date"])
    run_id = post_body["run_id"]
    # Either field clashing with an existing run makes this a duplicate.
    dagrun_instance = (
        session.query(DagRun)
        .filter(
            DagRun.dag_id == dag_id,
            or_(DagRun.run_id == run_id, DagRun.execution_date == logical_date),
        )
        .first()
    )
    if not dagrun_instance:
        try:
            dag = current_app.dag_bag.get_dag(dag_id)
            dag_run = dag.create_dagrun(
                run_type=DagRunType.MANUAL,
                run_id=run_id,
                execution_date=logical_date,
                data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date),
                state=State.QUEUED,
                conf=post_body.get("conf"),
                external_trigger=True,
                dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
            )
            return dagrun_schema.dump(dag_run)
        except ValueError as ve:
            # create_dagrun rejects invalid run parameters with ValueError.
            raise BadRequest(detail=str(ve))

    # Report which of the two unique fields collided.
    if dagrun_instance.execution_date == logical_date:
        raise AlreadyExists(
            detail=(
                f"DAGRun with DAG ID: '{dag_id}' and "
                f"DAGRun logical date: '{logical_date.isoformat(sep=' ')}' already exists"
            ),
        )

    raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists")
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """Clear task instances.

    Supports narrowing by a specific DAG run and/or a subset of tasks, with
    optional expansion to upstream/downstream and past/future instances.
    ``dry_run`` reports what would be cleared without changing anything.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG or the referenced DAG run does not exist
    """
    body = get_json_request_dict()
    try:
        data = clear_task_instance_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    dag = get_airflow_app().dag_bag.get_dag(dag_id)
    if not dag:
        error_message = f"Dag id {dag_id} not found"
        raise NotFound(error_message)
    reset_dag_runs = data.pop('reset_dag_runs')
    dry_run = data.pop('dry_run')
    # We always pass dry_run here, otherwise this would try to confirm on the terminal!
    dag_run_id = data.pop('dag_run_id', None)
    future = data.pop('include_future', False)
    past = data.pop('include_past', False)
    downstream = data.pop('include_downstream', False)
    upstream = data.pop('include_upstream', False)

    if dag_run_id is not None:
        dag_run: Optional[DR] = (
            session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
        )
        if dag_run is None:
            error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
            raise NotFound(error_message)
        # Pin the clearing window to the run's logical date; past/future
        # flags below widen it into an open interval.
        data['start_date'] = dag_run.logical_date
        data['end_date'] = dag_run.logical_date
    if past:
        data['start_date'] = None
    if future:
        data['end_date'] = None

    task_ids = data.pop('task_ids', None)
    if task_ids is not None:
        # Entries may be (task_id, map_index) tuples; extract plain task ids.
        task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids]
        dag = dag.partial_subset(
            task_ids_or_regex=task_id,
            include_downstream=downstream,
            include_upstream=upstream,
        )

        if len(dag.task_dict) > 1:
            # If we had upstream/downstream etc then also include those!
            # BUGFIX: the check was ``tid != task_id``, comparing a task-id
            # string against the *list* of requested ids — always True — so
            # every requested task was appended a second time. Only add the
            # tasks the subset pulled in beyond those already requested.
            task_ids.extend(tid for tid in dag.task_dict if tid not in task_id)

    task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
    if not dry_run:
        clear_task_instances(
            task_instances.all(),
            session,
            dag=dag,
            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
        )

    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=task_instances.all())
    )
def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse:
    """Update the specific DAG.

    Only the ``is_paused`` field may be changed through this endpoint.

    :raises BadRequest: on schema violations or a non-``is_paused`` mask
    :raises NotFound: if the DAG does not exist
    """
    try:
        patch_body = dag_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    if update_mask:
        patch_body_ = {}
        # Reject any mask other than exactly ['is_paused'].
        if update_mask != ['is_paused']:
            raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
        patch_body_[update_mask[0]] = patch_body[update_mask[0]]
        patch_body = patch_body_
    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
    if not dag:
        raise NotFound(f"Dag with id: '{dag_id}' not found")
    dag.is_paused = patch_body['is_paused']
    session.flush()
    return dag_schema.dump(dag)
def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response:
    """Update a variable by key.

    The body key must match the URL key; when an update mask is given it may
    not include ``key`` and must include ``value``.

    :raises BadRequest: on schema violations, mismatched keys, or an
        invalid update mask
    """
    try:
        data = variable_schema.load(get_json_request_dict())
    except ValidationError as err:
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))

    if data["key"] != variable_key:
        raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter")

    if update_mask:
        if "key" in update_mask:
            # BUGFIX: message previously read "ready only field".
            raise BadRequest("key is a read only field")
        if "value" not in update_mask:
            raise BadRequest("No field to update")

    Variable.set(data["key"], data["val"])
    return variable_schema.dump(data)
def post_variables() -> Response:
    """Create a variable from the JSON request body."""
    payload = request.json
    try:
        parsed = variable_schema.load(payload)
    except ValidationError as err:
        # Schema violations surface as a 400 carrying the validation messages.
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))
    Variable.set(parsed["key"], parsed["val"])
    return variable_schema.dump(parsed)
def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION
                                  ) -> APIResponse:
    """Set a state of task instances.

    The target instance may be addressed by ``execution_date`` or
    ``dag_run_id``; whichever is given is verified to exist before the state
    change is delegated to ``dag.set_task_instance_state``.

    :raises BadRequest: if the request body fails schema validation
    :raises NotFound: if the DAG, task, or task instance does not exist
    """
    body = request.get_json()
    try:
        data = set_task_instance_state_form.load(body)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    error_message = f"Dag ID {dag_id} not found"
    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        raise NotFound(error_message)

    task_id = data['task_id']
    task = dag.task_dict.get(task_id)

    if not task:
        error_message = f"Task ID {task_id} not found"
        raise NotFound(error_message)

    execution_date = data.get('execution_date')
    run_id = data.get('dag_run_id')
    if (execution_date and (session.query(TI).filter(
            TI.task_id == task_id, TI.dag_id == dag_id,
            TI.execution_date == execution_date).one_or_none()) is None):
        raise NotFound(
            detail=f"Task instance not found for task {task_id!r} on execution_date {execution_date}")

    # Primary-key lookup; map_index -1 targets the unmapped task instance.
    if run_id and not session.query(TI).get({
            'task_id': task_id, 'dag_id': dag_id, 'run_id': run_id, 'map_index': -1}):
        error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
        raise NotFound(detail=error_message)

    tis = dag.set_task_instance_state(
        task_id=task_id,
        dag_run_id=run_id,
        execution_date=execution_date,
        state=data["new_state"],
        upstream=data["include_upstream"],
        downstream=data["include_downstream"],
        future=data["include_future"],
        past=data["include_past"],
        commit=not data["dry_run"],
        session=session,
    )
    return task_instance_reference_collection_schema.dump(
        TaskInstanceReferenceCollection(task_instances=tis))
def post_pool(session):
    """Create a pool from the JSON request body.

    :raises BadRequest: if a required field is missing or the schema is invalid
    :raises AlreadyExists: if a pool with that name already exists
    """
    body = request.json
    # A pool requires both attributes; reject on the first one missing.
    for required in ("name", "slots"):
        if required not in body.keys():
            raise BadRequest(detail=f"'{required}' is a required property")
    try:
        post_body = pool_schema.load(body, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    new_pool = Pool(**post_body)
    try:
        session.add(new_pool)
        session.commit()
        return pool_schema.dump(new_pool)
    except IntegrityError:
        # Pool names are unique; a duplicate insert violates that constraint.
        raise AlreadyExists(detail=f"Pool: {post_body['pool']} already exists")
def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
    """Delete the named pool.

    The default pool is protected and can never be removed.

    :raises BadRequest: when attempting to delete the default pool
    :raises NotFound: when no pool with that name exists
    """
    # The default pool must always exist; refuse to delete it.
    if pool_name == "default_pool":
        raise BadRequest(detail="Default Pool can't be deleted")
    if session.query(Pool).filter(Pool.pool == pool_name).delete() == 0:
        raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
    return Response(status=HTTPStatus.NO_CONTENT)
def patch_variable(variable_key: str, update_mask: Optional[List[str]] = None) -> Response:
    """
    Update a variable by key.

    The body key must match the URL key; when an update mask is given it may
    not include ``key`` and must include ``value``.

    :raises BadRequest: on schema violations, mismatched keys, or an
        invalid update mask
    """
    try:
        data = variable_schema.load(request.json)
    except ValidationError as err:
        raise BadRequest("Invalid Variable schema", detail=str(err.messages))

    if data["key"] != variable_key:
        raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter")

    if update_mask:
        if "key" in update_mask:
            # BUGFIX: message previously read "ready only field".
            raise BadRequest("key is a read only field")
        if "value" not in update_mask:
            raise BadRequest("No field to update")

    Variable.set(data["key"], data["val"])
    return Response(status=204)
def patch_dag(session, dag_id, update_mask=None):
    """Update the specific DAG.

    Only the ``is_paused`` field may be changed through this endpoint.

    :raises NotFound: if the DAG does not exist
    :raises BadRequest: on schema violations or a non-``is_paused`` mask
    """
    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
    if not dag:
        raise NotFound(f"Dag with id: '{dag_id}' not found")
    try:
        patch_body = dag_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest("Invalid Dag schema", detail=str(err.messages))
    if update_mask:
        patch_body_ = {}
        # The mask must be exactly one field, and that field must be is_paused.
        if len(update_mask) > 1:
            raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
        update_mask = update_mask[0]
        if update_mask != 'is_paused':
            raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
        patch_body_[update_mask] = patch_body[update_mask]
        patch_body = patch_body_
    setattr(dag, 'is_paused', patch_body['is_paused'])
    session.commit()
    return dag_schema.dump(dag)
def post_pool(session):
    """Create a pool from the JSON request body.

    :raises BadRequest: if required fields are missing or the schema is invalid
    :raises AlreadyExists: if a pool with that name already exists
    """
    # A pool needs both attributes; report every missing one at once.
    missing = {"name", "slots"}.difference(request.json.keys())
    if missing:
        raise BadRequest(
            detail=f"Missing required property(ies): {sorted(missing)}")
    try:
        post_body = pool_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    new_pool = Pool(**post_body)
    try:
        session.add(new_pool)
        session.commit()
        return pool_schema.dump(new_pool)
    except IntegrityError:
        # Pool names are unique; a duplicate insert violates that constraint.
        raise AlreadyExists(detail=f"Pool: {post_body['pool']} already exists")
def autogenerate(self, data, **kwargs):
    """Fill in ``execution_date`` and ``dag_run_id`` when the caller omits them."""
    # Default the execution date to "now" so a bare POST body still works.
    if "execution_date" not in data:
        data["execution_date"] = str(timezone.utcnow())
    if "dag_run_id" in data:
        return data
    try:
        data["dag_run_id"] = DagRun.generate_run_id(
            DagRunType.MANUAL, timezone.parse(data["execution_date"])
        )
    except (ParserError, TypeError) as err:
        # A caller-supplied execution_date that cannot be parsed is a 400.
        raise BadRequest("Incorrect datetime argument", detail=str(err))
    return data
def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None):
    """Patch multiple DAGs.

    Applies the (``is_paused``-only) patch to every DAG the user may edit
    that matches the id pattern and tag filters, and returns the patched page
    of DAGs.

    :raises BadRequest: on schema violations or a non-``is_paused`` mask
    """
    try:
        patch_body = dag_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    if update_mask:
        patch_body_ = {}
        if update_mask != ['is_paused']:
            raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
        update_mask = update_mask[0]
        patch_body_[update_mask] = patch_body[update_mask]
        patch_body = patch_body_
    if only_active:
        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active)
    else:
        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)

    # '~' is the API's "match everything" pattern; translate to SQL wildcard.
    if dag_id_pattern == '~':
        dag_id_pattern = '%'
    dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
    editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)

    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
    if tags:
        cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
        dags_query = dags_query.filter(or_(*cond))

    total_entries = dags_query.count()

    dags = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()

    # Bulk-update only the DAGs on the returned page.
    dags_to_update = {dag.dag_id for dag in dags}
    session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update(
        {DagModel.is_paused: patch_body['is_paused']}, synchronize_session='fetch'
    )

    session.flush()

    return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))
def format_datetime(value: str):
    """
    Datetime format parser for args since connexion doesn't parse datetimes
    https://github.com/zalando/connexion/issues/476

    This should only be used within connection views because it raises 400

    :param value: datetime string from the query/path arguments
    :raises BadRequest: if the value cannot be parsed as a datetime
    """
    # '+' in a query string is URL-decoded to a space; restore it so UTC
    # offsets like "+00:00" survive, unless the value ends with a 'Z' suffix.
    if value[-1] != 'Z':
        value = value.replace(" ", '+')
    try:
        return timezone.parse(value)
    except (ParserError, TypeError) as err:
        raise BadRequest("Incorrect datetime argument", detail=str(err))
def patch_connection(
    *,
    connection_id: str,
    update_mask: UpdateMask = None,
    session: Session = NEW_SESSION,
) -> APIResponse:
    """Update a connection entry.

    ``connection_id``/``conn_id`` are immutable; an update mask restricts
    which of the remaining fields are applied.

    :raises BadRequest: on schema violations, an attempted conn_id change,
        or an unknown/forbidden field in the update mask
    :raises NotFound: if the connection does not exist
    """
    try:
        data = connection_schema.load(request.json, partial=True)
    except ValidationError as err:
        # If validation get to here, it is extra field validation.
        raise BadRequest(detail=str(err.messages))
    non_update_fields = ['connection_id', 'conn_id']
    connection = session.query(Connection).filter_by(conn_id=connection_id).first()
    if connection is None:
        raise NotFound(
            "Connection not found",
            detail=f"The Connection with connection_id: `{connection_id}` was not found",
        )
    if data.get('conn_id') and connection.conn_id != data['conn_id']:
        raise BadRequest(detail="The connection_id cannot be updated.")
    if update_mask:
        update_mask = [i.strip() for i in update_mask]
        data_ = {}
        for field in update_mask:
            if field in data and field not in non_update_fields:
                data_[field] = data[field]
            else:
                raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.")
        data = data_
    for key in data:
        setattr(connection, key, data[key])
    session.add(connection)
    session.commit()
    return connection_schema.dump(connection)