def Tx(self, request_iterator: Iterator[TxRequest], context: grpc.ServicerContext) -> Iterator[TxResponse]:
    """Mock handler for the streaming ``Tx`` RPC.

    Every incoming request is printed and recorded in ``self._requests``.
    The request is then offered, in order, to the queued mocks in
    ``self._responses``. The first mock whose ``test(request)`` returns a
    non-``None`` value is consumed (removed from the queue):

    * a ``TxResponse`` result is streamed back to the client;
    * any other result aborts the RPC with ``UNKNOWN`` and that result as
      the status details.

    When no mock matches, ``DONE`` is streamed instead.
    """
    self._requests = []
    for incoming in request_iterator:
        print(f"REQUEST: {incoming}")
        self._requests.append(incoming)
        handled = False
        for mock in self._responses:
            reply = mock.test(incoming)
            if reply is None:
                continue
            print(f"RESPONSE: {reply}")
            # Mutating the list is safe: iteration stops right after this match.
            self._responses.remove(mock)
            if isinstance(reply, TxResponse):
                yield reply
            else:
                # Non-TxResponse results are reported to the client as an error.
                # NOTE(review): ``error_type`` is not defined in this function;
                # presumably a module-level name — confirm.
                context.set_trailing_metadata([("ErrorType", error_type)])
                context.abort(grpc.StatusCode.UNKNOWN, reply)
            handled = True
            break
        if not handled:
            print(f"RESPONSE: {DONE}")
            yield DONE
def intercept(self, method: Callable, request: Any, context: grpc.ServicerContext, method_name: str) -> Any:
    """Override this method to implement a custom interceptor.

    You should call method(request, context) to invoke the next handler
    (either the RPC method implementation, or the next interceptor in the
    list).

    This implementation prints one line per RPC containing the method name,
    the client address and a timestamp. A ``GrpcException`` raised by the
    handler is logged together with the request and converted into a gRPC
    abort using its ``status_code`` and ``details``. Any other exception is
    printed and re-raised.

    Args:
        method: The next interceptor, or method implementation.
        request: The RPC request, as a protobuf message.
        context: The ServicerContext pass by gRPC to the service.
        method_name: A string of the form
            "/protobuf.package.Service/Method"

    Returns:
        This should generally return the result of
        method(request, context), which is typically the RPC method
        response, as a protobuf message. The interceptor is free to modify
        this in some way, however.
    """
    try:
        # Peer strings look like "ipv4:127.0.0.1:5000", "ipv6:[::1]:5000" or
        # "unix:/path". Strip the scheme prefix and the trailing ":port". A
        # naive split(':')[1] yields "[" for IPv6 and raises IndexError when
        # no ':' is present, which would fail the RPC just for logging it.
        peer = context.peer() or ""
        _, _, address = peer.partition(":")
        ip = address.rsplit(":", 1)[0] if address else peer
        print(
            f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)}'
        )
        return method(request, context)
    except GrpcException as e:
        print(
            f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)} -- {request}'
        )
        # context.abort() raises to terminate the RPC; the bare ``raise`` only
        # runs if abort() returns (e.g. with a stubbed context).
        context.abort(e.status_code, e.details)
        raise
    except Exception as e:
        template = "An exception of type {0} occurred. Arguments:\n{1!r}"
        message = template.format(type(e).__name__, e.args)
        print(message)
        raise
def validate_wrapper(self, request: Message, context: grpc.ServicerContext):
    """Validate ``request`` field by field, then call the wrapped ``func``.

    For every name in ``field_names`` a list of validators is assembled from
    the enclosing configuration:

    * a ``UUIDBytesValidator`` for UUID fields;
    * a ``NonEmptyValidator`` for non-empty fields;
    * a ``NonDefaultValidator`` for non-default fields;
    * any custom validator registered for the field.

    Each kind has a required and an optional variant. The list is handed to
    ``_recurse_validate``. All collected error messages are joined with
    ", ", truncated to 1000 characters and reported through an
    ``INVALID_ARGUMENT`` abort.

    NOTE(review): ``field_names``, ``func`` and the ``*_value`` collections
    are free variables from the enclosing scope (not visible here).
    """
    problems = []
    for name in field_names:
        optional = any(
            name in group
            for group in (
                optional_non_empty_value,
                optional_uuids_value,
                optional_non_default_value,
                optional_validators_value,
            )
        )
        checks: List[AbstractArgumentValidator] = []
        if name in uuids_value + optional_uuids_value:
            checks.append(UUIDBytesValidator())
        if name in non_empty_value + optional_non_empty_value:
            checks.append(NonEmptyValidator())
        if name in non_default_value + optional_non_default_value:
            checks.append(NonDefaultValidator())
        registered = itertools.chain(validators_value.keys(),
                                     optional_validators_value.keys())
        if name in registered:
            # When a field is in both mappings, the optional entry wins.
            custom = {**validators_value, **optional_validators_value}.get(name)
            if custom is not None:
                checks.append(custom)
        problems.extend(
            _recurse_validate(request,
                              name=name,
                              validators=checks,
                              is_optional=optional))
    if problems:
        context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                      ", ".join(problems)[:1000])
    return func(self, request, context)
def require_all(
    attrs: t.Collection[str],
    obj: object,
    ctx: grpc.ServicerContext,
    parent: str = "request",
) -> None:
    """Verify that all required arguments are supplied by the client.

    If any of the attributes is falsy, the context will be aborted with
    ``INVALID_ARGUMENT``.

    Args:
        attrs: Names of the required parameters. Dotted paths are allowed,
            as with ``operator.attrgetter``. A single name passed as a plain
            string is treated as a one-element collection. An empty
            collection requires nothing.
        obj: Current request message (e.g., a subclass of google.protobuf.message.Message).
        ctx: Current gRPC context.
        parent: Name of the parent message. Only used to compose more helpful error.
    """
    if isinstance(attrs, str):
        # Otherwise "name" would be spread into the attributes n, a, m, e.
        attrs = (attrs,)
    if not attrs:
        # attrgetter() rejects zero names; nothing is required anyway.
        return
    # Read every attribute up front so a misspelled name still raises
    # AttributeError instead of being hidden by all() short-circuiting.
    values = [attrgetter(name)(obj) for name in attrs]
    if not all(values):
        ctx.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"The message '{parent}' requires the following attributes: {attrs}.",
        )
def forbid_all(
    attrs: t.Collection[str],
    obj: object,
    ctx: grpc.ServicerContext,
    parent: str = "request",
) -> None:
    """Verify that no illegal combination of arguments is provided by the client.

    If every listed attribute is truthy at the same time, the context is
    aborted with ``INVALID_ARGUMENT``.

    Args:
        attrs: Names of parameters which cannot occur at the same time.
            Dotted paths are allowed, as with ``operator.attrgetter``. A
            single name passed as a plain string is treated as a one-element
            collection. An empty collection forbids nothing.
        obj: Current request message (e.g., a subclass of google.protobuf.message.Message).
        ctx: Current gRPC context.
        parent: Name of the parent message. Only used to compose more helpful error.
    """
    if isinstance(attrs, str):
        # Otherwise "name" would be spread into the attributes n, a, m, e.
        attrs = (attrs,)
    if not attrs:
        # attrgetter() rejects zero names; an empty combination forbids nothing.
        return
    # Read every attribute up front so a misspelled name still raises
    # AttributeError instead of being hidden by all() short-circuiting.
    values = [attrgetter(name)(obj) for name in attrs]
    if all(values):
        ctx.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"The message '{parent}' is not allowed to have the following parameter combination: {attrs}.",
        )
def context_abort_with_exception_traceback(context: grpc.ServicerContext,
                                           exc: Exception,
                                           status_code: grpc.StatusCode):
    """Abort the RPC, sending the exception's type, message and traceback.

    Args:
        context: Current gRPC servicer context to abort.
        exc: Exception to report; its ``__traceback__`` is formatted.
        status_code: Status code the RPC is aborted with.
    """
    # format_tb() returns a list of newline-terminated strings. Join them so
    # the client gets a readable traceback instead of a Python list repr.
    formatted_tb = ''.join(traceback.format_tb(exc.__traceback__))
    context.abort(
        code=status_code,
        details=(f'Exception Type: {type(exc)} \n'
                 f'Exception Message: {exc} \n'
                 f'Traceback: \n {formatted_tb}'))
def CreateTopic(
    self, request: pubsub_pb2.Topic, context: grpc.ServicerContext
):  # noqa: D403
    """CreateTopic implementation.

    Register ``request.name`` as a topic with no subscriptions yet. A
    duplicate name aborts with ``ALREADY_EXISTS``. The request itself is
    echoed back as the response.
    """
    self.logger.debug("CreateTopic(%s)", LazyFormat(request))
    topic_name = request.name
    if topic_name in self.topics:
        context.abort(grpc.StatusCode.ALREADY_EXISTS, "Topic already exists")
    self.topics[topic_name] = set()
    return request
def DeleteTopic(self, request: pubsub_pb2.DeleteTopicRequest, context: grpc.ServicerContext):  # noqa: D403
    """DeleteTopic implementation.

    Forget the topic named ``request.topic``, aborting with ``NOT_FOUND`` if
    it is not registered. Only ``self.topics`` is modified;
    ``self.subscriptions`` is left as is.
    """
    self.logger.debug("DeleteTopic(%s)", LazyFormat(request))
    # Sentinel distinguishes "absent" from any stored value.
    missing = object()
    if self.topics.pop(request.topic, missing) is missing:
        context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
    return empty_pb2.Empty()
def Query(self, query: warehouse_pb2.ProductQuery, context: grpc.ServicerContext) -> warehouse_pb2.Product:
    """Return the product whose id is ``query.id``.

    Aborts with ``NOT_FOUND`` when ``self._products.query`` returns a falsy
    result for that id.
    """
    logger.info("Query product id=%s", query.id)
    product = self._products.query(product_id=query.id)
    if not product:
        not_found_message = f"Product: {query.id} not found!"
        logger.info(not_found_message)
        context.abort(grpc.StatusCode.NOT_FOUND, not_found_message)
    # Return the result already fetched and checked. Querying again doubles
    # the lookup cost and could observe a different state than the one
    # validated above.
    return product
def get_mobility_node(session: Session, node_id: int, context: ServicerContext) -> Union[WlanNode, EmaneNet]:
    """Return node ``node_id`` as a ``WlanNode``, falling back to ``EmaneNet``.

    Each type is tried in turn with ``session.get_node``. A ``CoreError``
    moves on to the next type. If neither lookup succeeds, the RPC is aborted
    with ``NOT_FOUND``.
    """
    for node_type in (WlanNode, EmaneNet):
        try:
            return session.get_node(node_id, node_type)
        except CoreError:
            continue
    context.abort(grpc.StatusCode.NOT_FOUND, "node id is not for wlan or emane")
def handle_except(ex: Exception, ctx: grpc.ServicerContext) -> None:
    """Abort ``ctx`` with the fully formatted traceback of ``ex``.

    Meant to be called while handling an exception. The client receives the
    complete formatted traceback, including chained causes, as the status
    details, with status code ``UNKNOWN``.

    Args:
        ex: Exception that occurred.
        ctx: Current gRPC context.
    """
    report = traceback.TracebackException.from_exception(ex)
    details = "".join(report.format())
    ctx.abort(grpc.StatusCode.UNKNOWN, details)
def CreateSubscription(
    self, request: pubsub_pb2.Subscription, context: grpc.ServicerContext
):  # noqa: D403
    """CreateSubscription implementation.

    Register a new ``Subscription()`` under ``request.name`` and attach it to
    the topic ``request.topic``. Aborts as follows:

    * a name already in use gives ``ALREADY_EXISTS``;
    * an unknown topic gives ``NOT_FOUND``.

    The request is echoed back as the response.
    """
    self.logger.debug("CreateSubscription(%s)", LazyFormat(request))
    name, topic = request.name, request.topic
    if name in self.subscriptions:
        context.abort(grpc.StatusCode.ALREADY_EXISTS, "Subscription already exists")
    elif topic not in self.topics:
        context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
    created = Subscription()
    self.subscriptions[name] = created
    self.topics[topic].add(created)
    return request
def AddToList(self, request: models.AddToListRequest, context: grpc.ServicerContext) -> models.Item:
    """Append a new item to the existing to-do list ``request.list_name``.

    Aborts with ``INVALID_ARGUMENT`` when the list is unknown. The new item
    gets a fresh ``uuid4()`` id. It is stored under that ``UUID`` object as
    the key, while the response reports the id as a string.
    """
    if self.todo_lists.get(request.list_name, None) is None:
        context.abort(grpc.StatusCode.INVALID_ARGUMENT, "List does not exist")
    new_id = uuid4()
    # NOTE(review): new items are created with ItemStatus.COMPLETE; for a
    # freshly added to-do entry an "incomplete" status may be intended — confirm.
    created = Item(new_id, request.value, ItemStatus.COMPLETE)
    # NOTE(review): the key is a UUID object, but clients only ever see
    # str(id); a dict lookup with that string will not match this key.
    self.todo_lists[request.list_name][new_id] = created
    return models.Item(item_id=str(created.id), value=created.value, status=created.status.value)
def CreateList(self, request: models.CreateListRequest, context: grpc.ServicerContext) -> models.Empty:
    """Create an empty to-do list called ``request.name``.

    Aborts with ``INVALID_ARGUMENT`` for an empty name and with
    ``ALREADY_EXISTS`` when a list of that name is already present.
    """
    list_name = request.name
    if not list_name:
        context.abort(grpc.StatusCode.INVALID_ARGUMENT, "List name cannot be empty")
    if list_name in self.todo_lists:
        context.abort(grpc.StatusCode.ALREADY_EXISTS, "List already exists.")
    self.todo_lists[list_name] = dict()
    return models.Empty()
def wrapper(self, request, context: ServicerContext):
    """Authenticate the call from its metadata, then call ``func`` with claims.

    Steps:

    1. Find the metadata entry whose key equals ``config.token_header_key``.
    2. Take the text after ``token_header`` in its value.
    3. Decode that text with ``get_token_claims``.

    A missing header, a header without the ``token_header`` prefix, and a
    token that yields no claims all abort with ``UNAUTHENTICATED``.
    """
    metadata = context.invocation_metadata()
    token = next((x for x in metadata if x[0] == config.token_header_key), None)
    if token is None:
        context.abort(StatusCode.UNAUTHENTICATED, "Permission denied")
    # Without this check, a header lacking ``token_header`` raises IndexError,
    # which gRPC reports as UNKNOWN rather than UNAUTHENTICATED.
    parts = token.value.split(token_header)
    if len(parts) < 2:
        context.abort(StatusCode.UNAUTHENTICATED, "Permission denied")
    claims = get_token_claims(parts[1])
    if claims is None:
        context.abort(StatusCode.UNAUTHENTICATED, "Permission denied")
    else:
        return func(self, request, context, claims=claims)
def wrapped_catch_them_all(self, request, context: grpc.ServicerContext):
    """Call ``func``; on any exception log it and abort with ``INTERNAL``.

    Logging: the failure is logged with the elapsed time when ``request``
    carries a ``start_time`` attribute, otherwise with "?".

    Status handling: if a non-OK status code is already set on the context,
    the exception is re-raised untouched so that code is kept. This is
    typically the case when the handler itself called ``context.abort``.
    """
    try:
        return func(self, request, context)
    except Exception as e:
        start_time = getattr(request, "start_time", None)
        if start_time is not None:
            delta = time.perf_counter() - start_time
            self._log.exception("FAIL %.3f", delta)
        else:
            self._log.exception("FAIL ?")
        # context.abort() signals termination by raising a plain Exception,
        # so deliberate aborts from ``func`` land here too. Aborting again
        # would replace their status (e.g. NOT_FOUND) with INTERNAL.
        try:
            prior_code = context.code()
        except (AttributeError, NotImplementedError):
            # ServicerContext.code() is unavailable on older grpcio releases.
            prior_code = None
        if prior_code is not None and prior_code != grpc.StatusCode.OK:
            raise
        context.abort(grpc.StatusCode.INTERNAL, "%s: %s" % (type(e), e))
def Acknowledge(self, request: pubsub_pb2.AcknowledgeRequest, context: grpc.ServicerContext):
    """Acknowledge implementation.

    Remove the given ack IDs from the subscription's ``pulled`` messages.
    Aborts with ``NOT_FOUND`` in two cases:

    * the subscription is unknown;
    * any ack ID is not currently pulled. In that case nothing is removed.

    An ID repeated within one request is acknowledged once.
    """
    self.logger.debug("Acknowledge(%s)", LazyFormat(request))
    try:
        subscription = self.subscriptions[request.subscription]
    except KeyError:
        context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
    # Validate every ID before mutating, so one bad ID cannot leave the
    # subscription partially acknowledged.
    if any(ack_id not in subscription.pulled for ack_id in request.ack_ids):
        context.abort(grpc.StatusCode.NOT_FOUND, "Ack ID not found")
    for ack_id in request.ack_ids:
        # The default tolerates an ID that appears twice in the request.
        subscription.pulled.pop(ack_id, None)
    return empty_pb2.Empty()
def permission_denied(self, request, context: grpc.ServicerContext, message=None, code=None):
    """Abort the RPC for a request that failed the permission check.

    Chooses the status code as follows:

    * ``UNAUTHENTICATED`` when the request has authenticators but none of
      them succeeded;
    * ``PERMISSION_DENIED`` otherwise.

    ``message`` (or ``''`` when falsy) becomes the status details. The
    ``code`` argument is accepted but not used.
    """
    details = message if message else ''
    if request.authenticators and not request.successful_authenticator:
        context.abort(grpc.StatusCode.UNAUTHENTICATED, details=details)
    # context.abort() raises, so this line is reached only by callers that
    # were not rejected as unauthenticated above.
    context.abort(grpc.StatusCode.PERMISSION_DENIED, details=details)
def DeleteSubscription(
    self,
    request: pubsub_pb2.DeleteSubscriptionRequest,
    context: grpc.ServicerContext,
):  # noqa: D403
    """DeleteSubscription implementation.

    Remove the subscription from the registry and detach it from every
    topic. Aborts with ``NOT_FOUND`` if it does not exist.
    """
    self.logger.debug("DeleteSubscription(%s)", LazyFormat(request))
    try:
        removed = self.subscriptions.pop(request.subscription)
    except KeyError:
        context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
    for members in self.topics.values():
        members.discard(removed)
    return empty_pb2.Empty()
def _check_audio(samples: Sequence[float], context: grpc.ServicerContext): """ Check provided audio samples for validity. Returns gRPC errors (using ``context``) if not. :param samples: audio samples. :param context: gRPC context. """ if not len(samples): context.abort(grpc.StatusCode.INVALID_ARGUMENT, "no audio samples provided") if len(samples) > 16000: context.abort(grpc.StatusCode.INVALID_ARGUMENT, "too many audio samples provided")
def UpdateTopic(
    self, request: pubsub_pb2.UpdateTopicRequest, context: grpc.ServicerContext
):
    """Repurpose UpdateTopic API for setting up test conditions.

    :param request.topic.name: Name of the topic that needs overrides.
    :param request.update_mask.paths: A list of overrides, of the form
        "key=value".

    Valid override keys are "status_code" and "sleep". An override value of ""
    disables the override.

    For the override key "status_code" the override value indicates the
    status code that should be returned with an empty response by Publish
    requests, and non-empty override values must be a property of
    `grpc.StatusCode` such as "UNIMPLEMENTED".

    For the override key "sleep" the override value indicates a number of
    seconds Publish requests should sleep before being processed, and
    non-empty override values must be a valid, finite, non-negative float.

    Paths without "=" are rejected with INVALID_ARGUMENT; unknown keys are
    rejected with NOT_FOUND.
    """
    self.logger.debug("UpdateTopic(%s)", LazyFormat(request))
    topic_name = request.topic.name
    for override in request.update_mask.paths:
        key, separator, value = override.partition("=")
        if not separator:
            # Reject instead of failing with an unpacking ValueError (UNKNOWN).
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid override")
        key = key.lower()
        if key in ("status_code", "statuscode"):
            if value:
                try:
                    # Enum lookup by member name, e.g. "UNIMPLEMENTED".
                    self.status_codes[topic_name] = grpc.StatusCode[value.upper()]
                except KeyError:
                    context.abort(
                        grpc.StatusCode.INVALID_ARGUMENT, "Invalid status code"
                    )
            else:
                try:
                    del self.status_codes[topic_name]
                except KeyError:
                    context.abort(
                        grpc.StatusCode.NOT_FOUND, "Status code override not found"
                    )
        elif key == "sleep":
            if value:
                try:
                    sleep = float(value)
                except ValueError:
                    sleep = None
                # time.sleep() raises for negative, NaN and infinite values,
                # so reject them here instead of breaking every Publish.
                # NaN fails both comparisons; inf fails the upper bound.
                if sleep is None or not 0 <= sleep < float("inf"):
                    context.abort(
                        grpc.StatusCode.INVALID_ARGUMENT, "Invalid sleep time"
                    )
                self.sleep = sleep
            else:
                self.sleep = None
        else:
            context.abort(grpc.StatusCode.NOT_FOUND, "Path not found")
    return request.topic
def get_nem_id(node: CoreNode, netif_id: int, context: ServicerContext) -> int:
    """
    Get nem id for a given node and interface id.

    Aborts with NOT_FOUND when the node has no such interface, and with
    INVALID_ARGUMENT when the interface's network is not an ``EmaneNet``.

    :param node: node to get nem id for
    :param netif_id: id of interface on node to get nem id for
    :param context: request context
    :return: nem id
    """
    interface = node.netif(netif_id)
    if not interface:
        context.abort(
            grpc.StatusCode.NOT_FOUND,
            f"{node.name} missing interface {netif_id}",
        )
    network = interface.net
    if not isinstance(network, EmaneNet):
        context.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"{node.name} interface {netif_id} is not an EMANE network",
        )
    return network.getnemid(interface)
def Put(self, request_iterator: connectme_pb2.FileChunk, context: grpc.ServicerContext):
    """Write/create the incoming files.

    Hands the client stream to ``self.fileChunkReceiver``. Returns a
    ``PutReturn`` with the total file and byte counts it reports. Failures
    are logged as warnings and mapped to status codes:

    * ``IsADirectoryError`` aborts with ``INVALID_ARGUMENT``;
    * any other exception aborts with ``UNKNOWN``.

    NOTE(review): the annotation names the chunk message type, but the
    argument is consumed as a stream (``request_iterator``) — presumably an
    iterator of ``FileChunk``; confirm.
    """
    total_files: int = 0
    total_bytes: int = 0
    try:
        # paths are with respect to server side fs
        (total_files, total_bytes) = self.fileChunkReceiver(request_iterator, False)
    except IsADirectoryError as e:
        details = "Tried to write to directory: \"{}\"".format(e)
        # logging.warn is a deprecated alias; warning() is the supported API.
        logging.warning(details)
        context.abort(grpc.StatusCode.INVALID_ARGUMENT, details)
    except Exception as e:
        details = "Unknown error on open or write: \"{}\"".format(e)
        logging.warning(details)
        context.abort(grpc.StatusCode.UNKNOWN, details)
    return connectme_pb2.PutReturn(total_files=total_files, total_bytes=total_bytes)
def get(self, request: BookGetRequest, context: grpc.ServicerContext) -> "BookGetReply":
    """Return the stored book with id ``request.id``, wrapped in a ``BookGetReply``.

    Aborts with ``INVALID_ARGUMENT`` when no book with that id is stored.
    The return annotation is quoted so it does not depend on definition
    order at import time.
    """
    LOG.info("BookRpc get method", book_id=request.id if hasattr(request, 'id') else None)
    if request.id not in self.state:
        # NOTE(review): gRPC's status-code guide reserves NOT_FOUND for a
        # missing entity; INVALID_ARGUMENT is kept here because clients may
        # depend on it — consider migrating.
        return context.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"Book with id [{request.id}] does not exists!")
    return BookGetReply(book=self.state[request.id])
def Publish(self, request: pubsub_pb2.PublishRequest, context: grpc.ServicerContext):
    """Publish implementation.

    Test overrides in ``self.status_codes`` and ``self.sleep`` are applied
    first:

    * a status-code override for the topic aborts with that code;
    * a configured sleep delays the call.

    Unknown topics abort with ``NOT_FOUND``. Otherwise every message is
    stamped with a fresh hex ``message_id``. The messages are then appended
    to each subscription of the topic, and the new ids are returned.
    """
    self.logger.debug("Publish(%.100s)", LazyFormat(request))
    if request.topic in self.status_codes:
        context.abort(self.status_codes[request.topic], "Override")
    if self.sleep is not None:
        time.sleep(self.sleep)
    try:
        targets = self.topics[request.topic]
    except KeyError:
        context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
    assigned_ids = []
    for message in request.messages:
        new_id = uuid.uuid4().hex
        message.message_id = new_id
        assigned_ids.append(new_id)
    for subscription in targets:
        subscription.published.extend(request.messages)
    return pubsub_pb2.PublishResponse(message_ids=assigned_ids)
def Get(self, request_iterator: connectme_pb2.FilePath, context: grpc.ServicerContext):
    """Fetch the incoming file patterns.

    Each requested path is expanded server-side with ``self.expandPath``.
    The chunks produced by ``self.fileChunkGenerator`` are streamed back.
    Failures are logged as warnings and mapped to status codes:

    * ``FileNotFoundError``, including a pattern that matches nothing,
      aborts with ``NOT_FOUND``;
    * any other exception aborts with ``UNKNOWN``.
    """
    for entry in request_iterator:
        try:
            # paths are with respect to server side fs
            paths = self.expandPath(entry.path)
            if len(paths) == 0:
                raise FileNotFoundError(entry.path)
            for chunk in self.fileChunkGenerator(paths, False):
                yield chunk
        except FileNotFoundError as e:
            details = "File not found: \"{}\"".format(e)
            # logging.warn is a deprecated alias; warning() is the supported API.
            logging.warning(details)
            context.abort(grpc.StatusCode.NOT_FOUND, details)
        except Exception as e:
            details = "Unknown error on open or read: \"{}\"".format(e)
            logging.warning(details)
            context.abort(grpc.StatusCode.UNKNOWN, details)
def configure_node(session: Session, node: core_pb2.Node, core_node: NodeBase, context: ServicerContext) -> None:
    """Apply the per-node configuration carried by ``node`` to ``session``.

    Applied in order:

    1. EMANE model configs, keyed per interface.
    2. WLAN basic-range and NS2 scripted-mobility model configs.
    3. Legacy service configs and their files.
    4. Config-service configs and templates.

    Config-service settings require ``core_node`` to be a ``CoreNode``;
    otherwise the call aborts with ``INVALID_ARGUMENT``.
    """

    def _unwrap(proto_map):
        # Keep only each entry's ``.value`` from the proto config map.
        return {key: item.value for key, item in proto_map.items()}

    for emane_config in node.emane_configs:
        config_id = utils.iface_config_id(node.id, emane_config.iface_id)
        session.emane.set_config(config_id, emane_config.model, _unwrap(emane_config.config))
    if node.wlan_config:
        session.mobility.set_model_config(node.id, BasicRangeModel.name, _unwrap(node.wlan_config))
    if node.mobility_config:
        session.mobility.set_model_config(node.id, Ns2ScriptedMobility.name, _unwrap(node.mobility_config))
    for service_name, service_config in node.service_configs.items():
        data = service_config.data
        service_configuration(
            session,
            ServiceConfig(
                node_id=node.id,
                service=service_name,
                startup=data.startup,
                validate=data.validate,
                shutdown=data.shutdown,
                files=data.configs,
                directories=data.dirs,
            ),
        )
        for file_name, file_data in service_config.files.items():
            session.services.set_service_file(node.id, service_name, file_name, file_data)
    if not node.config_service_configs:
        return
    if not isinstance(core_node, CoreNode):
        context.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            "invalid node type with config service configs",
        )
    for service_name, service_config in node.config_service_configs.items():
        service = core_node.config_services[service_name]
        if service_config.config:
            service.set_config(service_config.config)
        for name, template in service_config.templates.items():
            service.set_template(name, template)
def UpdateItemStatus(self, request: models.UpdateItemStatusRequest, context: grpc.ServicerContext) -> models.Item:
    """Replace an existing item of a to-do list with ``request.item``.

    The item is looked up by ``request.item.item_id``, either as the raw
    string or parsed as a ``uuid.UUID``. Aborts with ``INVALID_ARGUMENT``
    when the list or the item does not exist. The replacement is stored
    under whichever key matched.
    """
    from uuid import UUID

    list_name = request.list_name
    items = self.todo_lists.get(list_name, None)
    if items is None:
        context.abort(grpc.StatusCode.INVALID_ARGUMENT, "List does not exist")
    item_id = request.item.item_id
    # AddToList stores items under uuid4() objects, while clients send the id
    # as a string. A plain string lookup would never find those items, so the
    # parsed UUID form is tried as well.
    candidates = [item_id]
    try:
        candidates.append(UUID(item_id))
    except (AttributeError, TypeError, ValueError):
        pass  # not a UUID string; only the raw id can match
    key = next((k for k in candidates if items.get(k, None) is not None), None)
    if key is None:
        context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Item does not exist")
    updated_item = grpc_to_item(request.item)
    items[key] = updated_item
    return models.Item(
        item_id=str(updated_item.id),
        value=updated_item.value,
        # NOTE(review): AddToList passes ``status.value`` here; if
        # grpc_to_item yields an ItemStatus Enum (not IntEnum), this may
        # need ``.value`` too — confirm.
        status=updated_item.status,
    )
def ModifyAckDeadline(
    self,
    request: pubsub_pb2.ModifyAckDeadlineRequest,
    context: grpc.ServicerContext,
) -> empty_pb2.Empty:  # noqa: D403
    """ModifyAckDeadline implementation.

    Deadlines are not tracked. A deadline of 0 moves the given ack IDs'
    messages from ``pulled`` back to ``published``. Any other deadline is
    accepted without effect.

    Aborts with ``NOT_FOUND`` in two cases:

    * the subscription is unknown;
    * the deadline is 0 and an ack ID is not currently pulled. In that case
      no message is moved.

    An ID repeated within one request is moved once.
    """
    self.logger.debug("ModifyAckDeadline(%s)", LazyFormat(request))
    try:
        subscription = self.subscriptions[request.subscription]
    except KeyError:
        context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
    # deadline is not tracked so only handle expiration when set to 0
    if request.ack_deadline_seconds == 0:
        # Validate first so one bad ID cannot leave the request half-applied.
        if any(ack_id not in subscription.pulled for ack_id in request.ack_ids):
            context.abort(grpc.StatusCode.NOT_FOUND, "Ack ID not found")
        for ack_id in request.ack_ids:
            # Skip an ID already moved earlier in this same request.
            if ack_id in subscription.pulled:
                # move message from pulled back to published
                subscription.published.append(subscription.pulled.pop(ack_id))
    return empty_pb2.Empty()
def Get(self, request: GetAccountRequest, context: grpc.ServicerContext) -> ApiResponse:
    """Handle get request.

    Sends a ``GetStateRequest`` for ``request.account_id`` through the client
    from ``self._get_cos_client()``. The unpacked state is returned as
    ``ApiResponse(account=...)``.

    ``validate(...)(context)`` is called with ``NOT_FOUND`` when the reply
    has no state or it unpacks to ``None``. It presumably aborts the RPC in
    that case — confirm.

    A backend ``RpcError`` with ``NOT_FOUND`` is forwarded as-is. Every
    other backend error becomes ``INTERNAL``; both keep the backend's
    details.
    """
    logger.info("getting account")
    active_global_tracer_span(context.get_active_span())
    cos_client = self._get_cos_client()
    state_request = GetStateRequest()
    state_request.entity_id = request.account_id
    try:
        reply = cos_client.GetState(state_request)
        validate(reply.HasField("state"), "state was none", StatusCode.NOT_FOUND)(context)
        state = self._cos_unpack_state(reply.state)
        validate(state is not None, "state was none", StatusCode.NOT_FOUND)(context)
    except grpc.RpcError as e:
        forwarded = e.code() if e.code() == StatusCode.NOT_FOUND else StatusCode.INTERNAL
        context.abort(code=forwarded, details=e.details())
    return ApiResponse(account=state)