def validate_wrapper(self, request: Message, context: grpc.ServicerContext):
    """Run the configured argument validators, then delegate to ``func``.

    For each name in ``field_names`` (taken from the enclosing decorator), a
    list of validators is assembled from the enclosing configuration: UUID,
    non-empty, non-default and custom validators, each with a required and an
    optional variant. That list is passed to ``_recurse_validate`` together
    with a flag saying whether the field is optional.

    If any error strings were collected, the RPC is aborted with
    INVALID_ARGUMENT and the comma-joined errors, truncated to 1000
    characters. Otherwise the wrapped ``func`` is called and its result is
    returned.
    """
    problems = []
    for name in field_names:
        optional_groups = (optional_non_empty_value, optional_uuids_value,
                           optional_non_default_value, optional_validators_value)
        optional = any(name in group for group in optional_groups)

        checks: List[AbstractArgumentValidator] = []
        if name in uuids_value + optional_uuids_value:
            checks.append(UUIDBytesValidator())
        if name in non_empty_value + optional_non_empty_value:
            checks.append(NonEmptyValidator())
        if name in non_default_value + optional_non_default_value:
            checks.append(NonDefaultValidator())
        if name in validators_value or name in optional_validators_value:
            # The optional mapping wins when a name appears in both mappings.
            custom = {**validators_value, **optional_validators_value}.get(name)
            if custom is not None:
                checks.append(custom)

        problems.extend(
            _recurse_validate(request, name=name, validators=checks,
                              is_optional=optional))

    if problems:
        # gRPC's context.abort raises, so func is not reached on failure.
        context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                      ", ".join(problems)[:1000])
    return func(self, request, context)
def Process(self, fs_msg: common_pb2.Message, context: grpc.ServicerContext):
    """Decode one fleetspeak message and hand its GRR messages on.

    A ``"GrrMessage"`` payload is parsed as a single GrrMessage. A
    ``"MessageList"`` payload is parsed as a PackedMessageList and
    decompressed, and its ``job`` entries are used. Both are forwarded to
    ``self._ProcessGRRMessages`` with the source client id and the validation
    tags.

    Any other message type is logged and INVALID_ARGUMENT is set on the
    context. Exceptions are logged with the offending message and re-raised.
    """
    try:
        validation_info = dict(fs_msg.validation_info.tags)
        message_type = fs_msg.message_type
        if message_type == "GrrMessage":
            grr_messages = [
                rdf_flows.GrrMessage.FromSerializedBytes(fs_msg.data.value)
            ]
        elif message_type == "MessageList":
            packed = rdf_flows.PackedMessageList.FromSerializedBytes(
                fs_msg.data.value)
            grr_messages = communicator.Communicator.DecompressMessageList(
                packed).job
        else:
            logging.error(
                "Received message with unrecognized message_type: %s",
                message_type)
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            return
        self._ProcessGRRMessages(fs_msg.source.client_id, grr_messages,
                                 validation_info)
    except Exception:
        logging.exception("Exception processing message: %s", fs_msg)
        raise
def DetectLanguageFromTweetText(
        self,
        request: DetectLanguageFromTweetTextRequest,
        context: grpc.ServicerContext
) -> DetectLanguageFromTweetTextResponse:
    """Guess the language of a stored tweet from its text.

    Question addressed (translated from French): "given the text of a tweet,
    is it possible to guess the language in which the tweet was written?"

    Args:
        request: Carries the ``tweet_id`` of the tweet to analyse. It is
            looked up as a string in ``db.tweets``.
        context: gRPC servicer context, used to report an unknown tweet.

    Returns:
        A response built from ``compute_detect_language(...).dict()`` plus the
        tweet ``text``. If no tweet matches ``tweet_id``, INVALID_ARGUMENT is
        set on the context and an empty response is returned.
    """
    # Fetch the tweet by id; detection runs on its stored text.
    with StorageDatabase() as db:
        stored_tweet = db.tweets.find_one({'tweet_id': str(request.tweet_id)})

    if stored_tweet is None:
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(
            "Can't find a tweet with tweet_id={} !".format(request.tweet_id))
        return DetectLanguageFromTweetTextResponse()

    text = stored_tweet['text']
    logger.debug(f"tweet_id={request.tweet_id} => text: {text}")

    detection = compute_detect_language(text)
    # Merge rather than pass text= separately: an explicit 'text' must
    # override any 'text' key produced by the detector.
    response_fields = {**detection.dict(), 'text': text}
    return DetectLanguageFromTweetTextResponse(**response_fields)
def Heartbeat(
        self, request: coordinator_pb2.HeartbeatRequest,
        context: grpc.ServicerContext) -> coordinator_pb2.HeartbeatReply:
    """The Heartbeat gRPC method.

    Participants periodically send a heartbeat so that the
    :class:`Coordinator` can detect failures. The request is forwarded to
    ``self.coordinator.on_message`` together with the caller's peer string.

    Args:
        request (:class:`~.coordinator_pb2.HeartbeatRequest`): The
            participant's request, carrying its current
            :class:`~.coordinator_pb2.State` and round number.
        context (:class:`~grpc.ServicerContext`): The context associated
            with the gRPC request.

    Returns:
        :class:`~.coordinator_pb2.HeartbeatReply`: The coordinator's reply.
        If the participant is unknown, PERMISSION_DENIED is set on the
        context and an empty reply is returned.
    """
    caller = context.peer()
    try:
        return self.coordinator.on_message(request, caller)
    except UnknownParticipantError as unknown:
        context.set_details(str(unknown))
        context.set_code(grpc.StatusCode.PERMISSION_DENIED)
    return coordinator_pb2.HeartbeatReply()
def Tx(self, request_iterator: Iterator[TxRequest],
       context: grpc.ServicerContext) -> Iterator[TxResponse]:
    """Mock streaming Tx handler that replays canned responses.

    Every incoming request is printed and recorded in ``self._requests``
    (reset on each call). The request is then offered to each entry of
    ``self._responses`` through ``mock_response.test(request)``. The first
    entry returning a non-None value is removed from ``self._responses``, and
    one of two things happens:

    * if the value is a ``TxResponse``, it is yielded;
    * otherwise, trailing metadata is set and the RPC is aborted with
      UNKNOWN, using the value as details.

    If no entry matches, ``DONE`` is printed and yielded.
    """
    self._requests = []
    for request in request_iterator:
        print(f"REQUEST: {request}")
        self._requests.append(request)
        for mock_response in self._responses:
            tx_response = mock_response.test(request)
            if tx_response is not None:
                print(f"RESPONSE: {tx_response}")
                # Each canned response is consumed once. Removing it from the
                # list being iterated is only safe because of the break below.
                self._responses.remove(mock_response)
                if isinstance(tx_response, TxResponse):
                    yield tx_response
                else:
                    # return an error message
                    # NOTE(review): `error_type` is not defined in this function;
                    # it must come from an enclosing/module scope — confirm.
                    context.set_trailing_metadata([("ErrorType", error_type)])
                    context.abort(grpc.StatusCode.UNKNOWN, tx_response)
                break
        else:
            # for/else: reached only when no mock response matched this request.
            print(f"RESPONSE: {DONE}")
            yield DONE
def intercept(self, method: Callable, request: Any,
              context: grpc.ServicerContext, method_name: str) -> Any:
    """Log each RPC, then invoke the next handler.

    A line containing the method name, the caller's host and the current
    time (formatted with ``DATE_FORMAT``) is printed before calling
    ``method(request, context)``.

    On a ``GrrpcException``-style ``GrpcException``, the line is printed again
    with the request appended, and the RPC is aborted with the exception's
    ``status_code`` and ``details``. Any other exception is printed with its
    type and args and then re-raised.

    Args:
        method: The next interceptor, or the method implementation.
        request: The RPC request, as a protobuf message.
        context: The ServicerContext passed by gRPC to the service.
        method_name: A string of the form "/protobuf.package.Service/Method".

    Returns:
        Whatever ``method(request, context)`` returns.
    """
    try:
        # gRPC peer strings look like "ipv4:127.0.0.1:5000",
        # "ipv6:[::1]:5000" or "unix:/path". Drop the scheme prefix and a
        # trailing numeric port. Unlike split(':')[1], this handles IPv6 and
        # port-less peers, so the log line itself cannot fail the RPC.
        peer = context.peer()
        _, scheme_sep, address = peer.partition(':')
        if not scheme_sep:
            address = peer
        host, port_sep, port = address.rpartition(':')
        ip = host if port_sep and port.isdigit() else address

        print(
            f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)}'
        )
        return method(request, context)
    except GrpcException as e:
        print(
            f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)} -- {request}'
        )
        context.abort(e.status_code, e.details)
        # gRPC's abort raises; this re-raise is only a fallback for contexts
        # whose abort() returns.
        raise
    except Exception as e:
        template = "An exception of type {0} occurred. Arguments:\n{1!r}"
        message = template.format(type(e).__name__, e.args)
        print(message)
        raise
def forbid_all(
    attrs: t.Collection[str],
    obj: object,
    ctx: grpc.ServicerContext,
    parent: str = "request",
) -> None:
    """Verify that no illegal combination of arguments is provided by the client.

    The context is aborted with INVALID_ARGUMENT when every attribute named in
    ``attrs`` is truthy on ``obj`` at the same time.

    Args:
        attrs: Names of parameters which cannot occur at the same time.
        obj: Current request message (e.g., a subclass of google.protobuf.message.Message).
        ctx: Current gRPC context.
        parent: Name of the parent message. Only used to compose a more helpful error.
    """
    values = attrgetter(*attrs)(obj)
    # attrgetter returns the bare value (not a tuple) for a single name.
    if len(attrs) == 1:
        values = [values]
    if all(values):
        ctx.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"The message '{parent}' is not allowed to have the following parameter combination: {attrs}.",
        )
def require_all(
    attrs: t.Collection[str],
    obj: object,
    ctx: grpc.ServicerContext,
    parent: str = "request",
) -> None:
    """Abort the RPC unless every named attribute of ``obj`` is truthy.

    Args:
        attrs: Names of the required parameters.
        obj: Current request message (e.g., a subclass of google.protobuf.message.Message).
        ctx: Current gRPC context; aborted with INVALID_ARGUMENT when any
            attribute is falsy.
        parent: Name of the parent message, used only in the error text.
    """
    fetched = attrgetter(*attrs)(obj)
    # A single-name attrgetter yields the bare value rather than a tuple.
    values = [fetched] if len(attrs) == 1 else fetched
    if any(not value for value in values):
        ctx.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"The message '{parent}' requires the following attributes: {attrs}.",
        )
async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
    """Answer a ping and, if the caller identified itself, record it.

    Some node wants us to add it to our routing table. The response always
    carries our node info, the caller's endpoint and the current DHT time.

    When the request includes a peer with ``node_id`` and ``rpc_port``:

    * the sender endpoint is resolved (the peer's preferred ``endpoint`` if
      set, else our view of the caller with its ``rpc_port``);
    * if ``request.validate`` is set, it is confirmed by pinging back;
    * a routing-table update is scheduled.
    """
    response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
                                    dht_time=get_dht_time(), available=False)

    caller = request.peer
    if not (caller and caller.node_id and caller.rpc_port):
        return response

    sender_id = DHTID.from_bytes(caller.node_id)
    unset_endpoint = dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value
    if caller.endpoint != unset_endpoint:
        sender_endpoint = caller.endpoint  # the peer advertised a preferred endpoint
    else:
        sender_endpoint = replace_port(context.peer(), new_port=caller.rpc_port)
    response.sender_endpoint = sender_endpoint

    if request.validate:
        echoed_id = await self.call_ping(response.sender_endpoint, validate=False)
        response.available = echoed_id == sender_id

    # NOTE(review): the task object is not retained; the asyncio docs advise
    # keeping a reference so a pending task is not garbage-collected.
    asyncio.create_task(
        self.update_routing_table(sender_id, sender_endpoint,
                                  responded=response.available or not request.validate))
    return response
def StreamCall(
        self, request_iterator: Iterable[phone_pb2.StreamCallRequest],
        context: grpc.ServicerContext
) -> Iterable[phone_pb2.StreamCallResponse]:
    """Simulate a phone call over a bidirectional stream.

    Only the first request is read (it carries the phone number). The stream
    then yields NEW, the call session info, ACTIVE and ENDED, with sleeps
    standing in for real call progress. Session cleanup is registered as a
    context callback.

    Raises:
        RuntimeError: if the client sends no request at all.
    """
    try:
        call_request = next(request_iterator)
    except StopIteration:
        raise RuntimeError("Failed to receive call request")
    logging.info("Received a phone call request for number [%s]",
                 call_request.phone_number)

    time.sleep(1)  # pretend the call request is being accepted
    yield create_state_response(phone_pb2.CallState.NEW)

    time.sleep(1)  # pretend the call session is being set up
    session = self._create_call_session()
    context.add_callback(lambda: self._clean_call_session(session))
    info_response = phone_pb2.StreamCallResponse()
    info_response.call_info.session_id = session.session_id
    info_response.call_info.media = session.media
    yield info_response
    yield create_state_response(phone_pb2.CallState.ACTIVE)

    time.sleep(2)  # pretend the call is in progress
    yield create_state_response(phone_pb2.CallState.ENDED)
    logging.info("Call finished [%s]", call_request.phone_number)
def UnaryCall(
        self, request: messages_pb2.SimpleRequest,
        context: grpc.ServicerContext) -> messages_pb2.SimpleResponse:
    """Reply with this server's id and hostname.

    The hostname is also sent as the ``hostname`` initial-metadata entry.
    """
    context.send_initial_metadata((('hostname', self._hostname),))
    return messages_pb2.SimpleResponse(server_id=self._server_id,
                                       hostname=self._hostname)
def context_abort_with_exception_traceback(context: grpc.ServicerContext,
                                           exc: Exception,
                                           status_code: grpc.StatusCode):
    """Abort the RPC with ``status_code`` and details describing ``exc``.

    The details contain the exception type, its message and its formatted
    traceback. Note that this exposes the server-side traceback to the client.

    Args:
        context: The servicer context to abort.
        exc: The exception to report.
        status_code: The gRPC status code to abort with.
    """
    # format_tb returns a list of strings. Interpolating the list directly
    # would print its repr (brackets, quotes, escaped "\n") instead of a
    # readable multi-line traceback.
    formatted_tb = ''.join(traceback.format_tb(exc.__traceback__))
    context.abort(
        code=status_code,
        details=(f'Exception Type: {type(exc)} \n'
                 f'Exception Message: {exc} \n'
                 f'Traceback: \n {formatted_tb}'))
def set_grpc_err(context: grpc.ServicerContext, code: grpc.StatusCode, details: str):
    """Set the status code and a comma-free details string on ``context``.

    Commas are stripped from ``details`` as a workaround for
    https://github.com/grpc/grpc-node/issues/769.
    """
    context.set_code(code)
    without_commas = details.replace(',', '')
    context.set_details(without_commas)
def DeleteTopic(self, request: pubsub_pb2.DeleteTopicRequest, context: grpc.ServicerContext):  # noqa: D403
    """DeleteTopic implementation: remove the topic or abort with NOT_FOUND."""
    self.logger.debug("DeleteTopic(%s)", LazyFormat(request))
    try:
        del self.topics[request.topic]
    except KeyError:
        context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
    return empty_pb2.Empty()
def CreateTopic(
    self, request: pubsub_pb2.Topic, context: grpc.ServicerContext
):  # noqa: D403
    """CreateTopic implementation: register a topic with no subscriptions.

    Aborts with ALREADY_EXISTS if the topic name is taken; otherwise the
    request is echoed back.
    """
    self.logger.debug("CreateTopic(%s)", LazyFormat(request))
    topic_name = request.name
    if topic_name in self.topics:
        context.abort(grpc.StatusCode.ALREADY_EXISTS, "Topic already exists")
    self.topics[topic_name] = set()
    return request
def DeleteItem(self, request: api_pb2.DeleteItemRequest, context: grpc.ServicerContext) -> api_pb2.DeleteItemResponse:
    """Delete the Item with ``request.id``.

    If no such Item exists, NOT_FOUND is set on the context. An empty
    response is returned in both cases.
    """
    response = api_pb2.DeleteItemResponse()
    try:
        db.models.Item.objects.get(id=request.id).delete()
    except db.models.Item.DoesNotExist:
        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details('Item does not exist')
    return response
def Query(self, query: warehouse_pb2.ProductQuery, context: grpc.ServicerContext) -> warehouse_pb2.Product:
    """Return the product with ``query.id``, aborting with NOT_FOUND if absent.

    Args:
        query: Carries the product ``id`` to look up.
        context: gRPC servicer context, aborted when no product matches.

    Returns:
        The product returned by ``self._products.query``.
    """
    logger.info("Query product id=%s", query.id)
    product = self._products.query(product_id=query.id)
    if not product:
        not_found_message = f"Product: {query.id} not found!"
        logger.info(not_found_message)
        context.abort(grpc.StatusCode.NOT_FOUND, not_found_message)
    # Return the product already fetched. Querying the store again doubled
    # the lookup cost and could return a different result than the one just
    # checked.
    return product
def SayHello(self, request: helloworld_pb2.HelloRequest,
             context: grpc.ServicerContext) -> helloworld_pb2.HelloReply:
    """Fail with FAILED_PRECONDITION and a details string of a requested length.

    ``request.name`` is expected to look like ``"<prefix>=<length>"``. The
    details are ``"x" * length``. A name that does not have that form yields
    INVALID_ARGUMENT instead of an unhandled IndexError/ValueError, which
    would otherwise surface to the client as UNKNOWN.
    """
    parts = request.name.split("=")
    try:
        err_length = int(parts[1])
    except (IndexError, ValueError):
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details("name must have the form '<prefix>=<length>'")
        return helloworld_pb2.HelloReply()
    err_msg = "x" * err_length
    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
    context.set_details(err_msg)
    logging.info("returning message length = %d", len(err_msg))
    return helloworld_pb2.HelloReply()
def wrapped_set_logging_context(self, request, context: grpc.ServicerContext):
    """Install a per-request logging context, then call the wrapped ``func``.

    The context dict is produced by the extractor registered for the
    request's type. It is then extended with the invocation metadata (under
    ``"meta"``) and the caller's peer string (under ``"peer"``).
    """
    log_context = request_log_context_extractors[type(request)](request)
    log_context["meta"] = {
        entry.key: entry.value for entry in context.invocation_metadata()
    }
    log_context["peer"] = context.peer()
    slogging.set_context(log_context)
    self._log.info("new %s", type(request).__name__)
    return func(self, request, context)
def AddItem(self, request: api_pb2.Item, context: grpc.ServicerContext) -> api_pb2.AddItemResponse:
    """Create, validate and save an Item from the request message.

    Returns:
        The saved item converted back to a protobuf message. If model
        validation fails, INVALID_ARGUMENT is set on the context with the
        JSON-encoded validation messages, and an empty response is returned.
    """
    try:
        # NOTE(review): MessageToDict emits JSON (camelCase) field names by
        # default; confirm they match the model's field names.
        item = db.models.Item(**_json_format.MessageToDict(request))
        item.full_clean()
        item.save()
        return api_pb2.AddItemResponse(item=convert(item))
    except ValidationError as e:
        # full_clean rejected client-supplied data. That is a client error, so
        # report INVALID_ARGUMENT rather than INTERNAL (which gRPC reserves for
        # broken server invariants).
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(json.dumps(e.messages))
        return api_pb2.AddItemResponse()
def get_mobility_node(session: Session, node_id: int, context: ServicerContext) -> Union[WlanNode, EmaneNet]:
    """Return the WLAN or EMANE node with ``node_id``.

    A WLAN node is tried first, then an EMANE network. If neither lookup
    succeeds, the RPC is aborted with NOT_FOUND.
    """
    for node_class in (WlanNode, EmaneNet):
        try:
            return session.get_node(node_id, node_class)
        except CoreError:
            continue  # not this kind of node; try the next one
    context.abort(grpc.StatusCode.NOT_FOUND, "node id is not for wlan or emane")
def StartTraining(
    self,
    request: coordinator_pb2.StartTrainingRequest,
    context: grpc.ServicerContext,
) -> coordinator_pb2.StartTrainingReply:
    """The StartTraining gRPC method.

    Once a participant is notified that the
    :class:`xain_fl.coordinator.coordinator.Coordinator` is in a round
    (through the state advertised in the
    :class:`~.coordinator_pb2.HeartbeatReply`), the participant should call
    this method to get the global model weights and start training for the
    round.

    Args:
        request (:class:`~.coordinator_pb2.StartTrainingRequest`): The
            participant's request.
        context (:class:`~grpc.ServicerContext`): The context associated with
            the gRPC request.

    Returns:
        :class:`~.coordinator_pb2.StartTrainingReply`: The reply containing
        the global model weights. An empty reply is returned in two cases:

        * unknown participant: PERMISSION_DENIED is set on the context;
        * invalid request: FAILED_PRECONDITION is set on the context.
    """
    try:
        return self.coordinator.on_message(request, context.peer())
    except UnknownParticipantError as error:
        context.set_details(str(error))
        context.set_code(grpc.StatusCode.PERMISSION_DENIED)
        return coordinator_pb2.StartTrainingReply()
    except InvalidRequestError as error:
        context.set_details(str(error))
        # Was `set_Code`, which raised AttributeError instead of setting the status.
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        return coordinator_pb2.StartTrainingReply()
def EndTraining(
        self, request: coordinator_pb2.EndTrainingRequest,
        context: grpc.ServicerContext) -> coordinator_pb2.EndTrainingReply:
    """The EndTraining gRPC method.

    A participant that has finished training for the round calls this method
    to submit its updated weights (and any metrics) to the
    :class:`xain_fl.coordinator.coordinator.Coordinator`. The request is
    forwarded to ``self.coordinator.on_message`` with the caller's peer string.

    Args:
        request (:class:`~.coordinator_pb2.EndTrainingRequest`): The
            participant's request with the updated weights.
        context (:class:`~grpc.ServicerContext`): The context associated with
            the gRPC request.

    Returns:
        :class:`~.coordinator_pb2.EndTrainingReply`: The acknowledgement. An
        empty reply is returned in two cases:

        * duplicate update: ALREADY_EXISTS is set on the context;
        * unknown participant: PERMISSION_DENIED is set on the context.
    """
    caller = context.peer()
    try:
        return self.coordinator.on_message(request, caller)
    except DuplicatedUpdateError as failure:
        context.set_details(str(failure))
        context.set_code(grpc.StatusCode.ALREADY_EXISTS)
    except UnknownParticipantError as failure:
        context.set_details(str(failure))
        context.set_code(grpc.StatusCode.PERMISSION_DENIED)
    return coordinator_pb2.EndTrainingReply()
def executor(self, request: Any,
             context: grpc.ServicerContext) -> executor_base.Executor:
    """Returns the executor which should be used to handle `request`.

    The executor is looked up by ``request.executor.id`` under ``self._lock``.

    Raises:
        RuntimeError: if no executor is registered for that id. In that case
            FAILED_PRECONDITION and the same message are set on the context
            before raising.
    """
    with self._lock:
        found = self._executors.get(request.executor.id)
    if found is not None:
        return found
    message = f'No executor found for executor id: {request.executor.id}'
    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
    context.set_details(message)
    raise RuntimeError(message)
def login(self, request: service_pb2.LoginRequest, context: grpc.ServicerContext) -> service_pb2.Empty:
    """Check the supplied password.

    On mismatch, PERMISSION_DENIED ("Invalid password") is set on the context.
    An ``Empty`` message is returned in every case.
    """
    print(request.email)
    # Never print or log the plaintext password.
    # NOTE(review): the hash compared against is hard-coded to "", and
    # bcrypt.checkpw documents bytes for both arguments. As written this cannot
    # authenticate anyone — TODO look up the user's stored hash and encode the
    # password.
    if not bcrypt.checkpw(request.password, ""):
        context.set_code(grpc.StatusCode.PERMISSION_DENIED)
        context.set_details("Invalid password")
        return service_pb2.Empty()
    # gRPC must serialize a response message; a bare `return` (None) fails.
    return service_pb2.Empty()
def Get(self, request: app_pb2.GetReq, context: grpc.ServicerContext) -> app_pb2.GetResp:
    """Return the stored value for ``request.key``.

    If the key is absent, NOT_FOUND is set on the context and an empty
    response is returned. Status codes:
    https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
    """
    key = request.key
    if key in self.store:
        return app_pb2.GetResp(val=self.store[key])
    context.set_code(grpc.StatusCode.NOT_FOUND)
    context.set_details('Key not found {}'.format(key))
    return app_pb2.GetResp()
def GetItem(self, request: api_pb2.GetItemRequest, context: grpc.ServicerContext) -> api_pb2.Item:
    """Fetch an Item, increment its ``pv`` counter, and return it.

    If the Item does not exist, NOT_FOUND is set on the context and an empty
    Item is returned.
    """
    items = db.models.Item.objects
    try:
        item = items.get(id=request.id)
        # F-expression: the increment happens in the database, and
        # refresh_from_db reloads the resulting value.
        item.pv = F('pv') + 1
        item.save()
        item.refresh_from_db()
        return convert(item)
    except db.models.Item.DoesNotExist:
        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details('Item does not exist')
    return api_pb2.Item()
def GetBeer(self, get_beer_request: bartender_pb2.GetBeerRequest,
            context: grpc.ServicerContext) -> bartender_pb2.Beer:
    """Return the first cached beer matching the requested brand and name.

    NOT_FOUND is set on the context, and an empty Beer returned, when the
    brand is unknown or has no beer with that name.
    """
    by_brand = self._cache.get(get_beer_request.brand, [])
    match = None
    if by_brand:
        match = next((b for b in by_brand if b.name == get_beer_request.name), None)
    if match is None:
        context.set_code(grpc.StatusCode.NOT_FOUND)
        return bartender_pb2.Beer()
    return match
def Connect(self, request: mw_protocols.ClientConnectRequest, context: grpc.ServicerContext) \
        -> Generator[mw_protocols.Command, None, None]:
    """
    Called by a client when it connects to the Avocado framework.
    Preserves a channel with the client for as long as the client is accessible.
    Note that the connection may break from the client side.

    If the client model accepts the client, one of two things happens first:

    * a new client receives its id;
    * a reconnecting client receives all of its known dependencies.

    After that the stream sends WAIT heartbeats, plus a SET_DEPENDENCY command
    whenever a dependency update is queued. If the client is rejected, a
    single ERROR command with the response code is sent.

    :return: stream of commands
    """

    def set_dependency_command():
        # Reads `dependency`, `ip`, `id_` and `client` from the enclosing scope
        # at call time; callers rebind dependency/ip right before each call.
        command = mw_protocols.Command()
        command.SET_DEPENDENCY.name = dependency
        command.SET_DEPENDENCY.ip = ip
        logging.info(
            f"Sent address of dependency {dependency} to client "
            f"{request.application}:{request.clientType}:{id_}. Value = {ip}"
        )
        client.last_call = time.perf_counter()
        return command

    # Get IP from request
    peer: str = context.peer()
    client_ip: str = self._parse_ip_from_peer(peer)

    # Prepare the client descriptor:
    client_descriptor = protocols.ClientDescriptor(
        application=request.application,
        type=request.clientType,
        ip=client_ip)
    if request.establishedConnection:
        client_descriptor.hasID = True
        client_descriptor.persistent_id = request.id

    rc, id_, client = self._client_model.assign_client(
        client_descriptor, context)

    if rc == mw_protocols.ClientResponseCode.Value("OK"):
        if not request.establishedConnection:
            yield mw_protocols.Command(SET_ID=id_)
        else:
            for dependency, ip in client.dependencies.items():
                yield set_dependency_command()
        while True:
            try:
                while len(client.dependency_updates) == 0:
                    time.sleep(self._client_model.wait_signal_frequency)
                    client.last_call = time.perf_counter()
                    yield mw_protocols.Command(WAIT=0)
                dependency, ip = client.dependency_updates.popitem()
                yield set_dependency_command()
            except Exception:
                # Previously the loop kept running after cancel(). A persistent
                # error then spun forever, cancelling repeatedly and hiding the
                # cause. Log it and end the stream instead.
                logging.exception(
                    "Command stream for client %s failed; cancelling", id_)
                context.cancel()
                return
    else:
        yield mw_protocols.Command(ERROR=rc)
    return
def handle_except(ex: Exception, ctx: grpc.ServicerContext) -> None:
    """Abort ``ctx`` with UNKNOWN, using the full formatted traceback of ``ex``.

    Intended to be called while handling an exception, so the client receives
    the server-side traceback.

    Args:
        ex: Exception that occurred.
        ctx: Current gRPC context.
    """
    report = traceback.TracebackException.from_exception(ex)
    details = "".join(report.format())
    ctx.abort(grpc.StatusCode.UNKNOWN, details)
def AddToList(self, request: models.AddToListRequest, context: grpc.ServicerContext) -> models.Item:
    """Append a new item with a fresh UUID to an existing todo list.

    Aborts with INVALID_ARGUMENT if the list does not exist. Returns the new
    item as a protobuf message.
    """
    if self.todo_lists.get(request.list_name) is None:
        context.abort(grpc.StatusCode.INVALID_ARGUMENT, "List does not exist")

    new_id = uuid4()
    # NOTE(review): new items are created with ItemStatus.COMPLETE; confirm
    # that is the intended initial status.
    entry = Item(new_id, request.value, ItemStatus.COMPLETE)
    self.todo_lists[request.list_name][new_id] = entry
    return models.Item(item_id=str(entry.id), value=entry.value, status=entry.status.value)
def CreateSubscription(
    self, request: pubsub_pb2.Subscription, context: grpc.ServicerContext
):  # noqa: D403
    """CreateSubscription implementation.

    Aborts with ALREADY_EXISTS for a taken subscription name, or NOT_FOUND for
    an unknown topic. Otherwise the subscription is registered both by name
    and on its topic, and the request is echoed back.
    """
    self.logger.debug("CreateSubscription(%s)", LazyFormat(request))
    if request.name in self.subscriptions:
        context.abort(grpc.StatusCode.ALREADY_EXISTS, "Subscription already exists")
    elif request.topic not in self.topics:
        context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
    new_subscription = Subscription()
    self.subscriptions[request.name] = new_subscription
    self.topics[request.topic].add(new_subscription)
    return request
def CreateList(self, request: models.CreateListRequest, context: grpc.ServicerContext) -> models.Empty:
    """Create an empty todo list under ``request.name``.

    Aborts with INVALID_ARGUMENT for an empty name, or ALREADY_EXISTS when the
    name is taken.
    """
    list_name = request.name
    if not list_name:
        context.abort(grpc.StatusCode.INVALID_ARGUMENT, "List name cannot be empty")
    if list_name in self.todo_lists:
        context.abort(grpc.StatusCode.ALREADY_EXISTS, "List already exists.")
    self.todo_lists[list_name] = {}
    return models.Empty()
def ListItem(self, request: api_pb2.ListItemRequest, context: grpc.ServicerContext) -> api_pb2.ListItemResponse:
    """Return one page of Items with total count and neighbouring page numbers.

    Args:
        request: ``page`` (1-based page number) and ``limit`` (page size).
        context: gRPC servicer context, used to report invalid paging input.

    Returns:
        The requested page. INVALID_ARGUMENT is set on the context, and an
        empty response returned, in two cases:

        * ``limit`` is not positive;
        * the page number is out of range.
    """
    page = request.page
    limit = request.limit
    if limit <= 0:
        # Paginator divides by the page size when computing num_pages, so a
        # zero limit (the proto default) raised an uncaught ZeroDivisionError.
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details('Invalid page size')
        return api_pb2.ListItemResponse()
    items = Paginator(db.models.Item.objects.all(), limit)
    try:
        return api_pb2.ListItemResponse(
            items=map(convert, items.page(page).object_list),
            total=db.models.Item.objects.count(),
            prevPage=max(1, page - 1),
            nextPage=min(page + 1, items.num_pages),
        )
    except InvalidPage:
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details('Invalid page number')
        return api_pb2.ListItemResponse()
def UpdateItem(self, request: api_pb2.UpdateItemRequest, context: grpc.ServicerContext) -> api_pb2.UpdateItemResponse:
    """Overwrite an existing Item's fields from ``request.item`` and save it.

    Returns:
        The updated item. An empty response is returned in two cases:

        * the Item does not exist: NOT_FOUND is set on the context;
        * model validation fails: INVALID_ARGUMENT is set on the context,
          with the JSON-encoded validation messages as details.
    """
    try:
        item = db.models.Item.objects.get(id=request.item.id)
        # NOTE(review): this writes raw MessageToDict values (JSON/camelCase
        # names by default) straight into the instance __dict__. full_clean
        # below is what converts and validates them — confirm the field names
        # match the model.
        item.__dict__.update(_json_format.MessageToDict(request)['item'])
        item.full_clean()
        item.save()
        item.refresh_from_db()
        return api_pb2.UpdateItemResponse(item=convert(item))
    except db.models.Item.DoesNotExist:
        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details('Item does not exist')
        return api_pb2.UpdateItemResponse()
    except ValidationError as e:
        # Validation failed on client-supplied data: a client error, so report
        # INVALID_ARGUMENT rather than INTERNAL.
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(json.dumps(e.messages))
        return api_pb2.UpdateItemResponse()