async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
    """
    Handle a ping from a node that wants to be added to our routing table.

    The reply always carries our own node info, the current DHT time and the endpoint we
    associate with the sender. If the sender identified itself (node id and rpc port), it is
    scheduled for insertion into the routing table; if it also asked for validation, we ping
    it back first and report whether it answered with the id it claimed.
    """
    reply = dht_pb2.PingResponse(
        peer=self.node_info,
        sender_endpoint=context.peer(),
        dht_time=get_dht_time(),
        available=False,
    )

    peer_info = request.peer
    if not (peer_info and peer_info.node_id and peer_info.rpc_port):
        # anonymous ping: nothing to add to the routing table
        return reply

    peer_id = DHTID.from_bytes(peer_info.node_id)

    unset_endpoint = dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value
    if peer_info.endpoint == unset_endpoint:
        # no preferred endpoint: take the address we observed and swap in the declared rpc port
        endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
    else:
        # the peer announced a preferred endpoint, use it
        endpoint = peer_info.endpoint
    reply.sender_endpoint = endpoint

    if request.validate:
        # ping the sender back (without asking it to validate us in turn) and compare ids
        answered_id = await self.call_ping(reply.sender_endpoint, validate=False)
        reply.available = answered_id == peer_id

    # when validation was not requested, the peer is treated as having responded
    responded = reply.available or not request.validate
    asyncio.create_task(self.update_routing_table(peer_id, endpoint, responded=responded))
    return reply
def Join(  # pylint: disable=invalid-name
    self,
    request_iterator: Iterator[ClientMessage],
    context: grpc.ServicerContext,
) -> Iterator[ServerMessage]:
    """Serve one participating GrpcClientProxy for the lifetime of its stream.

    Protocol:
    - The server speaks first: each round starts with a ServerMessage
    - ServerMessage and ClientMessage are only "wrappers" around the
      actual payload
    - This method is (pretty much) unaware of the protocol carried inside
      those wrappers; it only relays messages to and from the bridge
    """
    bridge = self.grpc_bridge_factory()
    client = self.client_factory(context.peer(), bridge)

    if not register_client(self.client_manager, client, context):
        # Registration failed: end the stream without sending anything
        return

    outgoing = bridge.server_message_iterator()
    incoming = request_iterator

    # Alternate strictly: one message out to the client, one answer back in
    while True:
        try:
            # Take the next server message from the bridge and send it
            yield next(outgoing)
            # Block until the client answers, then hand the answer to the bridge
            answer = next(incoming)
            bridge.set_client_message(answer)
        except StopIteration:
            break
def Heartbeat(
        self, request: coordinator_pb2.HeartbeatRequest,
        context: grpc.ServicerContext) -> coordinator_pb2.HeartbeatReply:
    """The Heartbeat gRPC method.

    Participants call this periodically so that the :class:`Coordinator`
    can detect failures.

    Args:
        request (:class:`~.coordinator_pb2.HeartbeatRequest`): The
            participant's request, carrying the
            :class:`~.coordinator_pb2.State` and round number the
            participant is currently on.
        context (:class:`~grpc.ServicerContext`): The context
            associated with the gRPC request.

    Returns:
        :class:`~.coordinator_pb2.HeartbeatReply`: The coordinator's
        reply, carrying the :class:`~.coordinator_pb2.State` and the
        round the coordinator is on (0 if no training session has
        started yet). If the participant is unknown, an empty reply is
        returned and the call is flagged ``PERMISSION_DENIED``.
    """
    try:
        reply = self.coordinator.on_message(request, context.peer())
    except UnknownParticipantError as error:
        # Report the failure through the gRPC status and send an empty reply
        context.set_details(str(error))
        context.set_code(grpc.StatusCode.PERMISSION_DENIED)
        reply = coordinator_pb2.HeartbeatReply()
    return reply
def StartTraining(
        self,
        request: coordinator_pb2.StartTrainingRequest,
        context: grpc.ServicerContext,
) -> coordinator_pb2.StartTrainingReply:
    """The StartTraining gRPC method.

    Once a participant is notified that the
    :class:`xain_fl.coordinator.coordinator.Coordinator` is in a round
    (through the state advertised in the
    :class:`~.coordinator_pb2.HeartbeatReply`), the participant should
    call this method to get the global model weights and start the
    training for the round.

    Args:
        request (:class:`~.coordinator_pb2.StartTrainingRequest`): The
            participant's request.
        context (:class:`~grpc.ServicerContext`): The context
            associated with the gRPC request.

    Returns:
        :class:`~.coordinator_pb2.StartTrainingReply`: The reply to the
        participant's request, containing the global model weights.
        On failure an empty reply is returned and the gRPC status is
        set to ``PERMISSION_DENIED`` (unknown participant) or
        ``FAILED_PRECONDITION`` (invalid request).
    """
    try:
        return self.coordinator.on_message(request, context.peer())
    except UnknownParticipantError as error:
        context.set_details(str(error))
        context.set_code(grpc.StatusCode.PERMISSION_DENIED)
        return coordinator_pb2.StartTrainingReply()
    except InvalidRequestError as error:
        context.set_details(str(error))
        # The method is ``set_code``; the previous ``set_Code`` spelling
        # raised AttributeError instead of reporting the status.
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        return coordinator_pb2.StartTrainingReply()
def EndTraining(
        self, request: coordinator_pb2.EndTrainingRequest,
        context: grpc.ServicerContext) -> coordinator_pb2.EndTrainingReply:
    """The EndTraining gRPC method.

    A participant calls this once it has finished training for the
    round, to submit its updated weights to the
    :class:`xain_fl.coordinator.coordinator.Coordinator`.

    Args:
        request (:class:`~.coordinator_pb2.EndTrainingRequest`): The
            participant's request, containing the updated weights
            produced by the training as well as any metrics helpful
            for the
            :class:`xain_fl.coordinator.coordinator.Coordinator`.
        context (:class:`~grpc.ServicerContext`): The context
            associated with the gRPC request.

    Returns:
        :class:`~.coordinator_pb2.EndTrainingReply`: An acknowledgment
        that the coordinator received the updated weights. On failure
        an empty reply is returned and the gRPC status is set to
        ``ALREADY_EXISTS`` (duplicated update) or ``PERMISSION_DENIED``
        (unknown participant).
    """

    def reject(error, status_code):
        # Report the failure through the gRPC status and send an empty reply
        context.set_details(str(error))
        context.set_code(status_code)
        return coordinator_pb2.EndTrainingReply()

    try:
        return self.coordinator.on_message(request, context.peer())
    except DuplicatedUpdateError as error:
        return reject(error, grpc.StatusCode.ALREADY_EXISTS)
    except UnknownParticipantError as error:
        return reject(error, grpc.StatusCode.PERMISSION_DENIED)
async def Infer(self, request: InferenceRequest,
                context: grpc.ServicerContext) -> InferenceResponse:
    """
    Classify the submitted audio data and keep a copy of it on disk.

    The samples are stored as a WAV file named after a fresh UUID and the inferred label,
    inside the configured inference upload directory.

    :param request: inference request.
    :param context: RPC context.
    :return: inference response.
    """
    self._check_audio(request.audio_samples, context)

    client = context.peer()
    sample_id = uuid.uuid4()
    logger.info(f"inferring sample from {client}: UUID {sample_id}")

    samples = np.array(request.audio_samples, dtype=np.float32)

    # The inferer is handed its own array, separate from the one written to disk below;
    # building it is part of the measured time.
    started_ns = time.monotonic_ns()
    inferer_input = np.array(request.audio_samples, dtype=np.float32)
    label = await self._inferer.infer(inferer_input)
    elapsed_ns = time.monotonic_ns() - started_ns
    logger.info(
        f"inferred {client}'s sample in {elapsed_ns}ns as having label {label}"
    )

    destination = self._config.infer_upload_path.joinpath(
        f"{sample_id}-{label.to_text()}.wav")
    wavfile.write(destination, 16000, samples)

    return InferenceResponse(label=label.value)
def intercept(self, method: Callable, request: Any,
              context: grpc.ServicerContext, method_name: str) -> Any:
    """Log every RPC with the caller's address and translate GrpcException into an abort.

    Calls method(request, context) to invoke the next handler (either
    the RPC method implementation, or the next interceptor in the list).

    Args:
        method: The next interceptor, or method implementation.
        request: The RPC request, as a protobuf message.
        context: The ServicerContext passed by gRPC to the service.
        method_name: A string of the form
            "/protobuf.package.Service/Method"

    Returns:
        The result of method(request, context), which is typically the
        RPC method response, as a protobuf message.

    Raises:
        Exception: whatever the handler raised is re-raised after being
            printed; for a GrpcException the RPC is aborted through
            ``context.abort`` with the exception's status code and details.
    """
    # Fallback so the error log below never hits an unbound name.
    ip = 'unknown'
    try:
        # Peer strings typically look like "ipv4:1.2.3.4:5678",
        # "ipv6:[::1]:5678" or "unix:/path". Drop the scheme, then drop the
        # trailing port if there is one. Splitting on every ':' (as before)
        # truncated IPv6 addresses to "[" and raised IndexError - failing the
        # RPC - when the peer string contained no ':' at all.
        _scheme, _, address = context.peer().partition(':')
        host, has_port, _port = address.rpartition(':')
        ip = host if has_port else address
        print(
            f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)}'
        )
        return method(request, context)
    except GrpcException as e:
        print(
            f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)} -- {request}'
        )
        context.abort(e.status_code, e.details)
        # Kept as a safety net in case abort() returns instead of raising.
        raise
    except Exception as e:
        template = "An exception of type {0} occurred. Arguments:\n{1!r}"
        message = template.format(type(e).__name__, e.args)
        print(message)
        raise
async def rpc_ping(self, peer_info: dht_pb2.NodeInfo, context: grpc.ServicerContext):
    """
    Handle a ping from a node that wants to be added to our routing table.

    If the caller supplied both a node id and an rpc port, a routing table update is scheduled
    for it. Our own node info is returned in every case.
    """
    if not (peer_info.node_id and peer_info.rpc_port):
        # caller did not identify itself: just answer
        return self.node_info

    node_id = DHTID.from_bytes(peer_info.node_id)
    # reach the caller on the address we observed, but on the rpc port it declared
    endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
    asyncio.create_task(self.update_routing_table(node_id, endpoint))
    return self.node_info
def wrapped_set_logging_context(self, request, context: grpc.ServicerContext):
    """
    Install a structured logging context for this request, then call the wrapped ``func``.

    The context is produced by the extractor registered for the request's exact type and is
    extended with the call's invocation metadata (under "meta") and the peer (under "peer").
    """
    extract = request_log_context_extractors[type(request)]
    log_context = extract(request)
    log_context["meta"] = {
        entry.key: entry.value for entry in context.invocation_metadata()
    }
    log_context["peer"] = context.peer()
    slogging.set_context(log_context)
    self._log.info("new %s", type(request).__name__)
    return func(self, request, context)
def Connect(self, request: mw_protocols.ClientConnectRequest, context: grpc.ServicerContext) \
        -> Generator[mw_protocols.Command, None, None]:
    """
    Called by a client when it connects to the Avocado framework. Preserves a channel with client for the whole
    time when client is accessible. Note that the connection may break from client side.

    The stream starts with either a SET_ID command (new connection) or one SET_DEPENDENCY command per known
    dependency (re-established connection). After that it emits WAIT commands while there are no pending
    dependency updates and a SET_DEPENDENCY command for each update that appears. If the client could not be
    assigned, a single ERROR command carrying the response code is sent instead.

    :param request: connect request with the client's application, type, and (when re-connecting) its ID
    :param context: RPC context; used to read the peer address and to cancel the call on errors
    :return: stream of commands
    """
    def set_dependency_command():
        # Closure: reads `dependency`, `ip`, `id_` and `client` from the enclosing scope at call time,
        # so they must be assigned before each call.
        command = mw_protocols.Command()
        command.SET_DEPENDENCY.name = dependency
        command.SET_DEPENDENCY.ip = ip
        logging.info(
            f"Sent address of dependency {dependency} to client "
            f"{request.application}:{request.clientType}:{id_}. Value = {ip}"
        )
        # Record when the client was last served
        client.last_call = time.perf_counter()
        return command

    # Get IP from request
    peer: str = context.peer()
    client_ip: str = self._parse_ip_from_peer(peer)
    # Prepare the client descriptor:
    client_descriptor = protocols.ClientDescriptor(
        application=request.application, type=request.clientType, ip=client_ip)
    if request.establishedConnection:
        # Re-connecting client: carry over the ID it already holds
        client_descriptor.hasID = True
        client_descriptor.persistent_id = request.id
    rc, id_, client = self._client_model.assign_client(
        client_descriptor, context)
    if rc == mw_protocols.ClientResponseCode.Value("OK"):
        if not request.establishedConnection:
            # New client: tell it which ID it was given
            yield mw_protocols.Command(SET_ID=id_)
        else:
            # Known client: replay the dependency addresses recorded for it
            for dependency, ip in client.dependencies.items():
                yield set_dependency_command()
        while True:
            try:
                # Poll for updates; each idle iteration sleeps and then sends a WAIT command
                while len(client.dependency_updates) == 0:
                    time.sleep(self._client_model.wait_signal_frequency)
                    client.last_call = time.perf_counter()
                    yield mw_protocols.Command(WAIT=0)
                dependency, ip = client.dependency_updates.popitem()
                yield set_dependency_command()
            except Exception:
                # NOTE(review): the call is cancelled but the loop is not left explicitly;
                # confirm that control is not expected to come back here after cancel().
                context.cancel()
    else:
        # Assignment failed: report the response code and end the stream
        yield mw_protocols.Command(ERROR=rc)
        return
def dispatch(self, request_object: Message,
             context_object: grpc.ServicerContext, func: Callable):
    """
    Run ``func(request_object)`` with the servicer context published through the ``context``
    context variable, translating known errors into gRPC status codes.

    :param request_object: the request message passed to ``func``.
    :param context_object: the gRPC servicer context of the current call.
    :param func: the handler to invoke with the request.
    :return: whatever ``func`` returns; ``None`` when an error was translated into a status code.
    """
    ctx = contextvars.copy_context()
    # Bind the servicer context inside the copy that `func` runs in. Setting it
    # in the ambient context after taking the copy (as before) meant `func`
    # saw whatever value was there previously - or none at all - and the new
    # value leaked into the calling thread's context.
    ctx.run(context.set, context_object)
    bound_logger = self.logger.bind(peer=context_object.peer())
    with bound_logger.catch():
        try:
            return ctx.run(func, request_object)
        except exceptions.GRPCError as e:
            # Handler signalled a gRPC-level failure: forward its code and optional details
            context_object.set_code(e.code)
            if e.details:
                context_object.set_details(e.details)
        except NotImplementedError:
            context_object.set_code(grpc.StatusCode.UNIMPLEMENTED)
async def Train(self, request: TrainingRequest,
                context: grpc.ServicerContext) -> Empty:
    """
    Store the submitted labelled audio as a training sample and, once enough samples have
    been collected, kick off re-training of the model.

    :param request: training request.
    :param context: RPC context.
    :return: empty response.
    """
    self._check_audio(request.audio_samples, context)

    client = context.peer()
    sample_id = uuid.uuid4()
    label = Labels(request.label)
    logger.info(
        f"saving training sample from {client}: UUID {sample_id}, label {label}")

    # Samples are kept in one directory per label, created on demand
    samples = np.array(request.audio_samples, dtype=np.float32)
    label_dir = self._config.training_upload_path.joinpath(label.to_text())
    label_dir.mkdir(mode=0o770, parents=True, exist_ok=True)

    # Scale the float samples by 32767 and store them as 16-bit integers
    pcm = (samples * 32767.0).astype(np.int16)
    wavfile.write(label_dir.joinpath(f"{sample_id}.wav"), 16000, pcm)

    pending = self._persistent_state.increment_untrained_samples(1)
    logger.info(f"server has collected {pending} untrained samples")

    if pending < self._config.samples_before_train:
        # Not enough new samples yet
        return Empty()

    logger.info(
        f"collected a sufficient amount of uploaded training samples: starting training"
    )
    if self._train_and_swap():
        logger.info(f"started training and swapping task")
    else:
        logger.info(f"training and swapping ongoing, skipping")
    return Empty()
def Rendezvous(
        self, request: coordinator_pb2.RendezvousRequest,
        context: grpc.ServicerContext) -> coordinator_pb2.RendezvousReply:
    """The Rendezvous gRPC method.

    A participant contacts the coordinator, which adds it to its list
    of participants. If the coordinator already has all the
    participants it needs, it tells the participant to try again later.

    Args:
        request (:class:`~.coordinator_pb2.RendezvousRequest`): The
            participant's request.
        context (:class:`~grpc.ServicerContext`): The context
            associated with the gRPC request.

    Returns:
        :class:`~.coordinator_pb2.RendezvousReply`: The reply to the
        participant's request. The reply is an enum containing either:

            ACCEPT: If the
                :class:`xain_fl.coordinator.coordinator.Coordinator`
                does not have enough participants.
            LATER: If the
                :class:`xain_fl.coordinator.coordinator.Coordinator`
                already has enough participants.
    """
    # The peer string is what the coordinator receives to identify the caller
    participant = context.peer()
    reply = self.coordinator.on_message(request, participant)
    return reply
def SomeReply(self, request, context: grpc.ServicerContext):
    """Print the caller's peer string, wait five seconds, then reply with a fixed message."""
    caller = context.peer()
    print(caller, flush=True)
    time.sleep(5)
    reply = common_pb2.SomeMessage(msg='bar')
    return reply