예제 #1
0
    def Tx(self, request_iterator: Iterator[TxRequest],
           context: grpc.ServicerContext) -> Iterator[TxResponse]:
        """Streaming ``Tx`` handler that replays pre-registered responses.

        Every received request is printed and recorded in ``self._requests``
        (reset at the start of each call). For each request, the first entry of
        ``self._responses`` whose ``test(request)`` returns something other
        than ``None`` is consumed (removed from the list):

        * a ``TxResponse`` result is yielded to the client;
        * any other result sets an ``ErrorType`` trailing-metadata entry and
          aborts the RPC with ``StatusCode.UNKNOWN``.

        When no registered response matches, ``DONE`` is yielded instead.
        """
        self._requests = []

        for request in request_iterator:
            print(f"REQUEST: {request}")
            self._requests.append(request)

            for mock_response in self._responses:
                tx_response = mock_response.test(request)
                if tx_response is not None:
                    print(f"RESPONSE: {tx_response}")
                    # Removing from the list being iterated is only safe
                    # because the inner loop is left via ``break`` right after.
                    self._responses.remove(mock_response)

                    if isinstance(tx_response, TxResponse):
                        yield tx_response
                    else:
                        # return an error message
                        # NOTE(review): ``error_type`` is not defined in this
                        # snippet; confirm it is a module-level/closure name,
                        # otherwise this raises NameError before aborting.
                        # NOTE(review): ``abort`` expects a str for details —
                        # confirm ``tx_response`` is one here.
                        context.set_trailing_metadata([("ErrorType",
                                                        error_type)])
                        context.abort(grpc.StatusCode.UNKNOWN, tx_response)

                    break
            else:
                # for/else: no registered response matched this request.
                print(f"RESPONSE: {DONE}")
                yield DONE
예제 #2
0
 def intercept(self, method: Callable, request: Any,
               context: grpc.ServicerContext, method_name: str) -> Any:
     """Log every RPC and convert ``GrpcException`` into a gRPC abort.

     Override this method to implement a custom interceptor.
     You should call method(request, context) to invoke the
     next handler (either the RPC method implementation, or the
     next interceptor in the list).

     Args:
         method: The next interceptor, or method implementation.
         request: The RPC request, as a protobuf message.
         context: The ServicerContext passed by gRPC to the service.
         method_name: A string of the form
             "/protobuf.package.Service/Method"

     Returns:
         This should generally return the result of
         method(request, context), which is typically the RPC
         method response, as a protobuf message. The interceptor
         is free to modify this in some way, however.

     Raises:
         Any exception raised by ``method`` is logged and re-raised; a
         ``GrpcException`` is first turned into
         ``context.abort(e.status_code, e.details)``.
     """
     ip = 'unknown'
     try:
         # peer() is typically "<scheme>:<host>:<port>" (e.g.
         # "ipv4:1.2.3.4:5678"), and IPv6 hosts contain colons themselves.
         # Drop the scheme prefix and the trailing port without indexing, so
         # IPv6 hosts stay intact and an unexpected format cannot fail the RPC
         # merely because of logging.
         _, _, address = context.peer().partition(':')
         ip = address.rsplit(':', 1)[0] if ':' in address else address
         print(
             f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)}'
         )
         return method(request, context)
     except GrpcException as e:
         print(
             f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)} -- {request}'
         )
         context.abort(e.status_code, e.details)
         # abort() is expected to raise; re-raise in case it returned.
         raise
     except Exception as e:
         template = "An exception of type {0} occurred. Arguments:\n{1!r}"
         message = template.format(type(e).__name__, e.args)
         print(message)
         raise
        def validate_wrapper(self, request: Message,
                             context: grpc.ServicerContext):
            """Validate the configured request fields, then call ``func``.

            Each name in ``field_names`` is checked with the validators chosen
            for it (UUID bytes, non-empty, non-default and/or a custom
            validator, in that order). A field listed in any ``optional_*``
            collection is validated as optional. If any error strings are
            collected, the call is aborted with INVALID_ARGUMENT and the
            comma-joined errors truncated to 1000 characters.
            """
            problems = []

            for name in field_names:
                optional = any(
                    name in group
                    for group in (optional_non_empty_value,
                                  optional_uuids_value,
                                  optional_non_default_value,
                                  optional_validators_value))

                chosen: List[AbstractArgumentValidator] = []
                if name in uuids_value + optional_uuids_value:
                    chosen.append(UUIDBytesValidator())
                if name in non_empty_value + optional_non_empty_value:
                    chosen.append(NonEmptyValidator())
                if name in non_default_value + optional_non_default_value:
                    chosen.append(NonDefaultValidator())
                if name in itertools.chain(validators_value.keys(),
                                           optional_validators_value.keys()):
                    # Optional entries take precedence over required ones.
                    merged = {**validators_value, **optional_validators_value}
                    custom = merged.get(name)
                    if custom is not None:
                        chosen.append(custom)

                problems.extend(
                    _recurse_validate(request,
                                      name=name,
                                      validators=chosen,
                                      is_optional=optional))

            if problems:
                context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                              ", ".join(problems)[:1000])
            return func(self, request, context)
예제 #4
0
def require_all(
    attrs: t.Collection[str],
    obj: object,
    ctx: grpc.ServicerContext,
    parent: str = "request",
) -> None:
    """Verify that all required arguments are supplied by the client.

    An attribute counts as missing when its value is falsy (e.g. ``""``, ``0``,
    ``False`` or an empty container). If any attribute is missing, the context
    is aborted with ``INVALID_ARGUMENT``. An empty ``attrs`` is a no-op.

    Args:
        attrs: Names of the required parameters (dotted paths are accepted, as
            with :func:`operator.attrgetter`).
        obj: Current request message (e.g., a subclass of google.protobuf.message.Message).
        ctx: Current gRPC context.
        parent: Name of the parent message. Only used to compose more helpful error.

    Raises:
        AttributeError: If ``obj`` has no attribute for one of ``attrs``.
    """
    # Resolve each name on its own: ``attrgetter()`` with no names raises
    # TypeError, and a single name would return a bare value instead of a tuple.
    values = [attrgetter(name)(obj) for name in attrs]

    if not all(values):
        ctx.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"The message '{parent}' requires the following attributes: {attrs}.",
        )
예제 #5
0
def forbid_all(
    attrs: t.Collection[str],
    obj: object,
    ctx: grpc.ServicerContext,
    parent: str = "request",
) -> None:
    """Verify that no illegal combination of arguments is provided by the client.

    The context is aborted with ``INVALID_ARGUMENT`` when every attribute in
    ``attrs`` is truthy at the same time. An empty ``attrs`` never aborts.

    Args:
        attrs: Names of parameters which cannot occur at the same time (dotted
            paths are accepted, as with :func:`operator.attrgetter`).
        obj: Current request message (e.g., a subclass of google.protobuf.message.Message).
        ctx: Current gRPC context.
        parent: Name of the parent message. Only used to compose more helpful error.

    Raises:
        AttributeError: If ``obj`` has no attribute for one of ``attrs``.
    """
    if not attrs:
        # No parameters means no combination to forbid (and ``attrgetter()``
        # with no names would raise TypeError).
        return

    values = [attrgetter(name)(obj) for name in attrs]

    if all(values):
        ctx.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"The message '{parent}' is not allowed to have the following parameter combination: {attrs}.",
        )
예제 #6
0
def context_abort_with_exception_traceback(context: grpc.ServicerContext,
                                           exc: Exception,
                                           status_code: grpc.StatusCode):
    """Abort ``context`` with ``status_code``, describing ``exc`` in the details.

    The details text contains the exception's type, its message and its
    formatted traceback frames (empty when ``exc`` was never raised).

    Args:
        context: Servicer context of the RPC being handled.
        exc: Exception to report to the client.
        status_code: gRPC status code to abort with.
    """
    # format_tb returns a list of newline-terminated strings; join them instead
    # of interpolating the list, which would send its repr (brackets, quotes and
    # escaped "\n" sequences) to the client.
    frames = ''.join(traceback.format_tb(exc.__traceback__))
    context.abort(
        code=status_code,
        details=(f'Exception Type: {type(exc)} \n'
                 f'Exception Message: {exc} \n'
                 f'Traceback: \n {frames}'))
예제 #7
0
 def CreateTopic(
     self, request: pubsub_pb2.Topic, context: grpc.ServicerContext
 ):  # noqa: D403
     """CreateTopic implementation.

     Registers ``request.name`` with an empty set of subscriptions and echoes
     the request back. A name that is already registered aborts the call with
     ALREADY_EXISTS.
     """
     self.logger.debug("CreateTopic(%s)", LazyFormat(request))
     topic_name = request.name
     already_known = topic_name in self.topics
     if already_known:
         context.abort(grpc.StatusCode.ALREADY_EXISTS, "Topic already exists")
     self.topics[topic_name] = set()
     return request
예제 #8
0
 def DeleteTopic(self, request: pubsub_pb2.DeleteTopicRequest,
                 context: grpc.ServicerContext):  # noqa: D403
     """DeleteTopic implementation.

     Forgets ``request.topic`` and returns an empty message; an unknown topic
     aborts the call with NOT_FOUND.
     """
     self.logger.debug("DeleteTopic(%s)", LazyFormat(request))
     topic_name = request.topic
     try:
         del self.topics[topic_name]
     except KeyError:
         context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
     return empty_pb2.Empty()
예제 #9
0
    def Query(self, query: warehouse_pb2.ProductQuery,
              context: grpc.ServicerContext) -> warehouse_pb2.Product:
        """Return the product whose id is ``query.id``.

        Aborts the RPC with NOT_FOUND (after logging the message) when the
        repository returns a falsy result for that id.
        """
        logger.info(f"Query product id={query.id}")
        product = self._products.query(product_id=query.id)
        if not product:
            not_found_message = f"Product: {query.id} not found!"
            logger.info(not_found_message)
            context.abort(grpc.StatusCode.NOT_FOUND, not_found_message)

        # Return the result of the single lookup rather than querying again.
        return product
예제 #10
0
def get_mobility_node(session: Session, node_id: int,
                      context: ServicerContext) -> Union[WlanNode, EmaneNet]:
    """
    Look up a node as a WlanNode, falling back to an EmaneNet.

    :param session: session to look the node up in
    :param node_id: id of the node to retrieve
    :param context: request context, aborted with NOT_FOUND when
        ``session.get_node`` raises CoreError for both node types
    :return: the matching WlanNode or EmaneNet
    """
    for node_class in (WlanNode, EmaneNet):
        try:
            return session.get_node(node_id, node_class)
        except CoreError:
            continue
    context.abort(grpc.StatusCode.NOT_FOUND,
                  "node id is not for wlan or emane")
예제 #11
0
def handle_except(ex: Exception, ctx: grpc.ServicerContext) -> None:
    """Abort ``ctx`` with UNKNOWN, sending the formatted traceback of ``ex``.

    Meant to be called while handling an exception: the complete traceback
    text (including any chained exceptions) becomes the status details.

    Args:
        ex: Exception that occurred.
        ctx: Current gRPC context.
    """
    report = traceback.TracebackException.from_exception(ex)
    details = "".join(report.format())
    ctx.abort(grpc.StatusCode.UNKNOWN, details)
예제 #12
0
 def CreateSubscription(
     self, request: pubsub_pb2.Subscription, context: grpc.ServicerContext
 ):  # noqa: D403
     """CreateSubscription implementation.

     Creates a fresh ``Subscription`` under ``request.name`` and attaches it to
     ``request.topic``, echoing the request back. Aborts with ALREADY_EXISTS
     for a duplicate name and with NOT_FOUND for an unknown topic.
     """
     self.logger.debug("CreateSubscription(%s)", LazyFormat(request))
     sub_name, topic_name = request.name, request.topic
     if sub_name in self.subscriptions:
         context.abort(grpc.StatusCode.ALREADY_EXISTS, "Subscription already exists")
     elif topic_name not in self.topics:
         context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
     new_subscription = Subscription()
     self.subscriptions[sub_name] = new_subscription
     self.topics[topic_name].add(new_subscription)
     return request
예제 #13
0
파일: server.py 프로젝트: Johnxjp/grpc_todo
    def AddToList(self, request: models.AddToListRequest,
                  context: grpc.ServicerContext) -> models.Item:
        """Add a new item holding ``request.value`` to an existing todo list.

        The item gets a freshly generated UUID and is stored under the string
        form of that id, which is the ``item_id`` returned to the client.
        Aborts with INVALID_ARGUMENT if the list does not exist.
        """
        items = self.todo_lists.get(request.list_name, None)
        if items is None:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                          "List does not exist")

        item_id = uuid4()
        # NOTE(review): new items start as ItemStatus.COMPLETE; confirm this is
        # intended rather than an incomplete/pending status.
        item = Item(item_id, request.value, ItemStatus.COMPLETE)
        # Key by the string id clients receive and send back (UpdateItemStatus
        # looks items up by ``request.item.item_id``); a UUID key never matches.
        items[str(item_id)] = item
        return models.Item(item_id=str(item.id),
                           value=item.value,
                           status=item.status.value)
예제 #14
0
파일: server.py 프로젝트: Johnxjp/grpc_todo
    def CreateList(self, request: models.CreateListRequest,
                   context: grpc.ServicerContext) -> models.Empty:
        """Create a new, empty todo list called ``request.name``.

        Aborts with INVALID_ARGUMENT for an empty name and with ALREADY_EXISTS
        when a list of that name is already present.
        """
        list_name = request.name
        if not list_name:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                          "List name cannot be empty")
        if list_name in self.todo_lists:
            context.abort(grpc.StatusCode.ALREADY_EXISTS,
                          "List already exists.")
        self.todo_lists[list_name] = dict()
        return models.Empty()
예제 #15
0
    def wrapper(self, request, context: ServicerContext):
        """Authenticate the call from its metadata, then invoke ``func``.

        Finds the metadata entry whose key equals ``config.token_header_key``,
        takes the text following ``token_header`` in its value and resolves it
        with ``get_token_claims``. Any failure aborts with UNAUTHENTICATED; on
        success ``func`` is called with an extra ``claims`` keyword argument.
        """
        metadata = context.invocation_metadata()
        token = next((x for x in metadata if x[0] == config.token_header_key),
                     None)
        if token is None:
            context.abort(StatusCode.UNAUTHENTICATED, "Permission denied")

        # Without the expected prefix split() yields a single element; reject
        # such a header as unauthenticated instead of letting IndexError
        # surface to the client as an UNKNOWN error.
        parts = token.value.split(token_header)
        if len(parts) < 2:
            context.abort(StatusCode.UNAUTHENTICATED, "Permission denied")

        claims = get_token_claims(parts[1])
        if claims is None:
            context.abort(StatusCode.UNAUTHENTICATED, "Permission denied")
        else:
            return func(self, request, context, claims=claims)
예제 #16
0
 def wrapped_catch_them_all(self, request,
                            context: grpc.ServicerContext):
     """Call ``func``; turn any exception it raises into an INTERNAL abort.

     The failure is logged with ``self._log.exception``. When the request has
     a non-None ``start_time`` attribute, the elapsed
     ``time.perf_counter() - start_time`` is included in the log line.
     """
     try:
         return func(self, request, context)
     except Exception as e:
         # NOTE(review): ``start_time`` is presumably a perf_counter() value
         # stored on the request by other code — confirm.
         start_time = getattr(request, "start_time", None)
         if start_time is not None:
             delta = time.perf_counter() - start_time
             self._log.exception("FAIL %.3f", delta)
         else:
             self._log.exception("FAIL ?")
         # NOTE(review): grpc's ``context.abort`` itself raises an Exception,
         # so a deliberate abort inside ``func`` (e.g. NOT_FOUND) also lands
         # here and is re-aborted as INTERNAL — confirm that is intended.
         context.abort(grpc.StatusCode.INTERNAL,
                       "%s: %s" % (type(e), e))
예제 #17
0
 def Acknowledge(self, request: pubsub_pb2.AcknowledgeRequest,
                 context: grpc.ServicerContext):
     """Acknowledge implementation.

     Removes each of ``request.ack_ids`` from the subscription's ``pulled``
     messages, in order. Aborts with NOT_FOUND for an unknown subscription or
     at the first unknown ack id (ids handled before it stay removed).
     """
     self.logger.debug("Acknowledge(%s)", LazyFormat(request))
     try:
         target = self.subscriptions[request.subscription]
     except KeyError:
         context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
     for acked in request.ack_ids:
         try:
             del target.pulled[acked]
         except KeyError:
             context.abort(grpc.StatusCode.NOT_FOUND, "Ack ID not found")
     return empty_pb2.Empty()
예제 #18
0
 def permission_denied(self,
                       request,
                       context: grpc.ServicerContext,
                       message=None,
                       code=None):
     """
     Abort the RPC for a request that failed its permission checks.

     Uses UNAUTHENTICATED when ``request.authenticators`` is truthy but
     ``request.successful_authenticator`` is not, and PERMISSION_DENIED
     otherwise. ``message`` (or an empty string) becomes the status details;
     ``code`` is accepted but not used.
     """
     details = message or ''
     unauthenticated = (request.authenticators
                        and not request.successful_authenticator)
     if unauthenticated:
         context.abort(grpc.StatusCode.UNAUTHENTICATED, details=details)
     context.abort(grpc.StatusCode.PERMISSION_DENIED, details=details)
예제 #19
0
 def DeleteSubscription(
     self,
     request: pubsub_pb2.DeleteSubscriptionRequest,
     context: grpc.ServicerContext,
 ):  # noqa: D403
     """DeleteSubscription implementation.

     Removes ``request.subscription`` and detaches it from every topic;
     aborts with NOT_FOUND when no such subscription exists.
     """
     self.logger.debug("DeleteSubscription(%s)", LazyFormat(request))
     try:
         removed = self.subscriptions.pop(request.subscription)
     except KeyError:
         context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
     for topic_members in self.topics.values():
         topic_members.discard(removed)
     return empty_pb2.Empty()
예제 #20
0
    def _check_audio(samples: Sequence[float], context: grpc.ServicerContext):
        """
        Validate the audio samples of a request.

        Aborts ``context`` with INVALID_ARGUMENT when there are no samples or
        when there are more than 16000 of them.

        :param samples: audio samples.
        :param context: gRPC context.
        """
        sample_count = len(samples)
        if sample_count == 0:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                          "no audio samples provided")
        if sample_count > 16000:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                          "too many audio samples provided")
예제 #21
0
    def UpdateTopic(
        self, request: pubsub_pb2.UpdateTopicRequest, context: grpc.ServicerContext
    ):
        """Repurpose UpdateTopic API for setting up test conditions.

        :param request.topic.name: Name of the topic that needs overrides.
        :param request.update_mask.paths: A list of overrides, of the form
        "key=value".

        Valid override keys are "status_code" and "sleep". An override value of
        "" disables the override.

        For the override key "status_code" the override value indicates the
        status code that should be returned with an empty response by Publish
        requests, and non-empty override values must be a property of
        `grpc.StatusCode` such as "UNIMPLEMENTED".

        For the override key "sleep" the override value indicates a number of
        seconds Publish requests should sleep before returning, and non-empty
        override values must be a valid float. Publish requests will return
        a valid response without recording messages.

        A path without "=" aborts with INVALID_ARGUMENT, and an unknown key
        aborts with NOT_FOUND.
        """
        self.logger.debug("UpdateTopic(%s)", LazyFormat(request))
        for override in request.update_mask.paths:
            raw_key, separator, value = override.partition("=")
            if not separator:
                # Reject malformed paths as a client error instead of letting
                # tuple unpacking fail with ValueError (an UNKNOWN error).
                context.abort(
                    grpc.StatusCode.INVALID_ARGUMENT, "Invalid override path"
                )
            key = raw_key.lower()
            if key in ("status_code", "statuscode"):
                if value:
                    try:
                        self.status_codes[request.topic.name] = getattr(
                            grpc.StatusCode, value.upper()
                        )
                    except AttributeError:
                        context.abort(
                            grpc.StatusCode.INVALID_ARGUMENT, "Invalid status code"
                        )
                else:
                    try:
                        del self.status_codes[request.topic.name]
                    except KeyError:
                        context.abort(
                            grpc.StatusCode.NOT_FOUND, "Status code override not found"
                        )
            elif key == "sleep":
                if value:
                    try:
                        self.sleep = float(value)
                    except ValueError:
                        context.abort(
                            grpc.StatusCode.INVALID_ARGUMENT, "Invalid sleep time"
                        )
                else:
                    self.sleep = None
            else:
                # StatusCode members are upper-case; ``Not_FOUND`` would raise
                # AttributeError instead of aborting.
                context.abort(grpc.StatusCode.NOT_FOUND, "Path not found")
        return request.topic
예제 #22
0
파일: grpcutils.py 프로젝트: walterhil/core
def get_nem_id(node: CoreNode, netif_id: int, context: ServicerContext) -> int:
    """
    Look up the EMANE nem id behind one of a node's interfaces.

    The request is aborted with NOT_FOUND when the node has no such interface,
    and with INVALID_ARGUMENT when the interface's network is not an EmaneNet.

    :param node: node to get nem id for
    :param netif_id: id of interface on node to get nem id for
    :param context: request context
    :return: nem id
    """
    interface = node.netif(netif_id)
    if not interface:
        context.abort(grpc.StatusCode.NOT_FOUND,
                      f"{node.name} missing interface {netif_id}")
    emane_net = interface.net
    if not isinstance(emane_net, EmaneNet):
        context.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"{node.name} interface {netif_id} is not an EMANE network",
        )
    return emane_net.getnemid(interface)
예제 #23
0
 def Put(self, request_iterator: connectme_pb2.FileChunk,
         context: grpc.ServicerContext):
     """Write/create the incoming files.

     Hands ``request_iterator`` to ``self.fileChunkReceiver`` and returns the
     number of files and bytes it reports in a ``PutReturn``. Aborts with
     INVALID_ARGUMENT when a target path is a directory and with UNKNOWN for
     any other error raised while receiving.
     """
     # NOTE(review): request_iterator presumably yields FileChunk messages
     # (client stream); the annotation names the element type only.
     total_files: int = 0
     total_bytes: int = 0
     try:
         # paths are with respect to server side fs
         (total_files,
          total_bytes) = self.fileChunkReceiver(request_iterator, False)
     except IsADirectoryError as e:
         details = "Tried to write to directory: \"{}\"".format(e)
         # logging.warn is a deprecated alias of logging.warning.
         logging.warning(details)
         context.abort(grpc.StatusCode.INVALID_ARGUMENT, details)
     except Exception as e:
         details = "Unknown error on open or write: \"{}\"".format(e)
         logging.warning(details)
         context.abort(grpc.StatusCode.UNKNOWN, details)
     return connectme_pb2.PutReturn(total_files=total_files,
                                    total_bytes=total_bytes)
예제 #24
0
    def get(self, request: BookGetRequest,
            context: grpc.ServicerContext) -> BookGetReply:
        """Return the book stored under ``request.id``, wrapped in a BookGetReply.

        Aborts with INVALID_ARGUMENT when ``request.id`` is not a key of
        ``self.state``.
        """
        LOG.info("BookRpc get method", book_id=request.id)
        if request.id not in self.state:
            # NOTE(review): NOT_FOUND is the conventional code for a missing
            # entity; INVALID_ARGUMENT is kept for client compatibility.
            return context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                f"Book with id [{request.id}] does not exist!")

        return BookGetReply(book=self.state[request.id])
예제 #25
0
 def Publish(self, request: pubsub_pb2.PublishRequest,
             context: grpc.ServicerContext):
     """Publish implementation.

     Test overrides are applied first: a status code registered for the topic
     aborts the call with that code, and ``self.sleep`` (when not None)
     delays it. Each message is then stamped in place with a fresh uuid4 hex
     ``message_id``, and all messages are appended to the ``published`` list
     of every subscription attached to the topic. Aborts with NOT_FOUND for
     an unknown topic.
     """
     # NOTE(review): UpdateTopic's docstring says sleeping Publish requests
     # return "without recording messages", but this path still records them
     # after sleeping — confirm which behavior is intended.
     self.logger.debug("Publish(%.100s)", LazyFormat(request))
     topic_name = request.topic
     if topic_name in self.status_codes:
         context.abort(self.status_codes[topic_name], "Override")
     if self.sleep is not None:
         time.sleep(self.sleep)
     try:
         subscriptions = self.topics[topic_name]
     except KeyError:
         context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
     message_ids: List[str] = []
     for message in request.messages:
         new_id = uuid.uuid4().hex
         message.message_id = new_id
         message_ids.append(new_id)
     for subscription in subscriptions:
         subscription.published.extend(request.messages)
     return pubsub_pb2.PublishResponse(message_ids=message_ids)
예제 #26
0
 def Get(self, request_iterator: connectme_pb2.FilePath,
         context: grpc.ServicerContext):
     """Fetch the incoming file patterns.

     For every incoming path message, expands ``file.path`` with
     ``self.expandPath`` and yields the chunks produced by
     ``self.fileChunkGenerator`` for the matches. Aborts with NOT_FOUND when
     nothing matches (or a FileNotFoundError is raised while streaming) and
     with UNKNOWN on any other error.
     """
     for file in request_iterator:
         try:
             # paths are with respect to server side fs
             paths = self.expandPath(file.path)
             if len(paths) == 0:
                 raise FileNotFoundError(file.path)
             for c in self.fileChunkGenerator(paths, False):
                 yield c
         except FileNotFoundError as e:
             details = "File not found: \"{}\"".format(e)
             # logging.warn is a deprecated alias of logging.warning.
             logging.warning(details)
             context.abort(grpc.StatusCode.NOT_FOUND, details)
         except Exception as e:
             details = "Unknown error on open or read: \"{}\"".format(e)
             logging.warning(details)
             context.abort(grpc.StatusCode.UNKNOWN, details)
예제 #27
0
def configure_node(session: Session, node: core_pb2.Node, core_node: NodeBase,
                   context: ServicerContext) -> None:
    """
    Apply the per-node configuration carried by a protobuf Node to a session.

    In order: EMANE model configs (keyed per node interface), the WLAN
    (BasicRangeModel) and mobility (Ns2ScriptedMobility) model configs,
    service configs and service files, and finally config-service configs
    and templates.

    :param session: session to apply the configuration to
    :param node: protobuf node carrying the requested configuration
    :param core_node: core node corresponding to ``node``; must be a CoreNode
        when config service configs are present
    :param context: request context, aborted with INVALID_ARGUMENT when
        config service configs are given for a non-CoreNode
    :return: nothing
    """
    for emane_config in node.emane_configs:
        # EMANE configs are stored under an id combining node and interface.
        _id = utils.iface_config_id(node.id, emane_config.iface_id)
        # Unwrap each config entry to its ``.value``.
        config = {k: v.value for k, v in emane_config.config.items()}
        session.emane.set_config(_id, emane_config.model, config)
    if node.wlan_config:
        config = {k: v.value for k, v in node.wlan_config.items()}
        session.mobility.set_model_config(node.id, BasicRangeModel.name,
                                          config)
    if node.mobility_config:
        config = {k: v.value for k, v in node.mobility_config.items()}
        session.mobility.set_model_config(node.id, Ns2ScriptedMobility.name,
                                          config)
    for service_name, service_config in node.service_configs.items():
        data = service_config.data
        config = ServiceConfig(
            node_id=node.id,
            service=service_name,
            startup=data.startup,
            validate=data.validate,
            shutdown=data.shutdown,
            files=data.configs,
            directories=data.dirs,
        )
        service_configuration(session, config)
        for file_name, file_data in service_config.files.items():
            session.services.set_service_file(node.id, service_name, file_name,
                                              file_data)
    if node.config_service_configs:
        # Config services only exist on CoreNode instances.
        if not isinstance(core_node, CoreNode):
            context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "invalid node type with config service configs",
            )
        for service_name, service_config in node.config_service_configs.items(
        ):
            # NOTE(review): an unknown service name presumably raises KeyError
            # here rather than aborting the RPC — confirm.
            service = core_node.config_services[service_name]
            if service_config.config:
                service.set_config(service_config.config)
            for name, template in service_config.templates.items():
                service.set_template(name, template)
예제 #28
0
파일: server.py 프로젝트: Johnxjp/grpc_todo
    def UpdateItemStatus(self, request: models.UpdateItemStatusRequest,
                         context: grpc.ServicerContext) -> models.Item:
        """Replace an existing item of a todo list with ``request.item``.

        The item is looked up by ``request.item.item_id`` in the list named
        ``request.list_name``. Aborts with INVALID_ARGUMENT when either the
        list or the item does not exist.
        """
        list_name = request.list_name
        items = self.todo_lists.get(list_name, None)
        if items is None:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                          "List does not exist")

        item_id = request.item.item_id
        if items.get(item_id, None) is None:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                          "Item does not exist")

        updated_item = grpc_to_item(request.item)
        items[item_id] = updated_item
        # Send the enum's underlying value as AddToList does; a plain Enum
        # member cannot be assigned to a protobuf field. Fall back to the raw
        # status if it is already a plain value.
        status = updated_item.status
        return models.Item(
            item_id=str(updated_item.id),
            value=updated_item.value,
            status=getattr(status, "value", status),
        )
예제 #29
0
 def ModifyAckDeadline(
     self,
     request: pubsub_pb2.ModifyAckDeadlineRequest,
     context: grpc.ServicerContext,
 ) -> empty_pb2.Empty:  # noqa: D403
     """ModifyAckDeadline implementation.

     Deadlines are not tracked. Only ``ack_deadline_seconds == 0`` has an
     effect: each listed message is moved from the subscription's ``pulled``
     mapping back to its ``published`` list. Aborts with NOT_FOUND for an
     unknown subscription or at the first unknown ack id.
     """
     self.logger.debug("ModifyAckDeadline(%s)", LazyFormat(request))
     try:
         subscription = self.subscriptions[request.subscription]
     except KeyError:
         context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
     if request.ack_deadline_seconds != 0:
         return empty_pb2.Empty()
     for ack_id in request.ack_ids:
         try:
             expired = subscription.pulled.pop(ack_id)
         except KeyError:
             context.abort(grpc.StatusCode.NOT_FOUND, "Ack ID not found")
         else:
             subscription.published.append(expired)
     return empty_pb2.Empty()
예제 #30
0
    def Get(self, request: GetAccountRequest, context: grpc.ServicerContext) -> ApiResponse:
        '''Fetch an account's state and return it in an ApiResponse.

        The account id is sent as ``entity_id`` of a ``GetStateRequest`` to the
        client returned by ``self._get_cos_client()``. ``validate`` is applied
        to the presence of the state and to the unpacked state (presumably
        aborting with NOT_FOUND when false). A ``grpc.RpcError`` from the
        client aborts this call, keeping NOT_FOUND and mapping any other code
        to INTERNAL, with the original details.
        '''
        logger.info("getting account")
        active_global_tracer_span(context.get_active_span())
        cos_client = self._get_cos_client()
        state_request = GetStateRequest()
        state_request.entity_id = request.account_id

        try:
            reply = cos_client.GetState(state_request)
            validate(reply.HasField("state"), "state was none", StatusCode.NOT_FOUND)(context)
            state = self._cos_unpack_state(reply.state)
            validate(state is not None, "state was none", StatusCode.NOT_FOUND)(context)
        except grpc.RpcError as rpc_error:
            upstream_code = rpc_error.code()
            if upstream_code == StatusCode.NOT_FOUND:
                context.abort(code=upstream_code, details=rpc_error.details())
            else:
                context.abort(code=StatusCode.INTERNAL, details=rpc_error.details())

        return ApiResponse(account=state)