Beispiel #1
0
    async def rpc_ping(self, request: dht_pb2.PingRequest,
                       context: grpc.ServicerContext):
        """Some node wants us to add it to our routing table.

        Builds a PingResponse describing this node. When the request carries
        the sender's identity (``node_id`` and ``rpc_port``), it works out the
        sender's endpoint. If ``request.validate`` is set, it pings the sender
        back. It then schedules a routing-table update in the background.

        :param request: incoming PingRequest; ``request.peer`` describes the
            sender and ``request.validate`` asks us to ping it back.
        :param context: gRPC servicer context, used for the sender's address.
        :returns: PingResponse with our node info, the sender endpoint as we
            see it, the current DHT time and the result of validation
            (``available`` stays False when no validation was done).
        """
        response = dht_pb2.PingResponse(peer=self.node_info,
                                        sender_endpoint=context.peer(),
                                        dht_time=get_dht_time(),
                                        available=False)

        if request.peer and request.peer.node_id and request.peer.rpc_port:
            sender_id = DHTID.from_bytes(request.peer.node_id)
            if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
                # the peer advertised a preferred endpoint: use it as-is
                sender_endpoint = request.peer.endpoint
            else:
                # otherwise reuse the address gRPC saw, with the peer's RPC port
                sender_endpoint = replace_port(context.peer(),
                                               new_port=request.peer.rpc_port)

            response.sender_endpoint = sender_endpoint
            if request.validate:
                # ping the sender back and check it answers with the same id
                response.available = await self.call_ping(
                    response.sender_endpoint, validate=False) == sender_id

            update_task = asyncio.create_task(
                self.update_routing_table(sender_id,
                                          sender_endpoint,
                                          responded=response.available
                                          or not request.validate))
            # The event loop keeps only weak references to tasks, so hold a
            # strong one until the update finishes; otherwise it may be
            # garbage-collected before it completes.
            pending = getattr(self, "_pending_routing_updates", None)
            if pending is None:
                pending = set()
                self._pending_routing_updates = pending
            pending.add(update_task)
            update_task.add_done_callback(pending.discard)

        return response
    def Join(  # pylint: disable=invalid-name
        self, request_iterator: Iterator[ClientMessage], context: grpc.ServicerContext,
    ) -> Iterator[ServerMessage]:
        """Relay messages between the server and one connecting client.

        Invoked once for every GrpcClientProxy that participates in the
        network. The caller is wrapped in a bridge-backed client and
        registered with the client manager. If registration fails, the stream
        ends immediately without yielding anything.

        Protocol:
            - The server speaks first: each message pulled from the bridge is
              yielded to the client, then exactly one client message is read
              back and handed to the bridge.
            - ServerMessage and ClientMessage are wrappers around the actual
              payload; this method never looks inside them.
            - The exchange stops as soon as either iterator is exhausted.
        """
        peer_address = context.peer()
        bridge = self.grpc_bridge_factory()
        client = self.client_factory(peer_address, bridge)
        if not register_client(self.client_manager, client, context):
            return

        outbound = bridge.server_message_iterator()
        while True:
            try:
                # Forward the next server instruction, then wait for the reply.
                yield next(outbound)
                bridge.set_client_message(next(request_iterator))
            except StopIteration:
                break
Beispiel #3
0
    def Heartbeat(
            self, request: coordinator_pb2.HeartbeatRequest,
            context: grpc.ServicerContext) -> coordinator_pb2.HeartbeatReply:
        """Handle the periodic liveness signal sent by a participant.

        Participants call this regularly so that the :class:`Coordinator` can
        notice when one of them stops responding. The request is forwarded to
        the coordinator together with the caller's peer address.

        Args:
            request (:class:`~.coordinator_pb2.HeartbeatRequest`): Carries
                the participant's current :class:`~.coordinator_pb2.State`
                and the round number it is on.
            context (:class:`~grpc.ServicerContext`): The gRPC context of
                this call.

        Returns:
            :class:`~.coordinator_pb2.HeartbeatReply`: The coordinator's
            :class:`~.coordinator_pb2.State` and current round (0 before a
            training session starts). If the coordinator raises
            ``UnknownParticipantError``, the call's status is set to
            PERMISSION_DENIED with the error text as details, and an empty
            reply is returned.
        """
        try:
            reply = self.coordinator.on_message(request, context.peer())
        except UnknownParticipantError as unknown:
            context.set_details(str(unknown))
            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
            reply = coordinator_pb2.HeartbeatReply()
        return reply
Beispiel #4
0
    def StartTraining(
        self,
        request: coordinator_pb2.StartTrainingRequest,
        context: grpc.ServicerContext,
    ) -> coordinator_pb2.StartTrainingReply:
        """The StartTraining gRPC method.

        Once a participant is notified that the :class:`xain_fl.coordinator.coordinator.Coordinator`
        is in a round (through the state advertised in the
        :class:`~.coordinator_pb2.HeartbeatReply`), the participant should call this
        method in order to get the global model weights in order to start the
        training for the round.

        Args:
            request (:class:`~.coordinator_pb2.StartTrainingRequest`): The participant's request.
            context (:class:`~grpc.ServicerContext`): The context associated with the gRPC request.

        Returns:
            :class:`~.coordinator_pb2.StartTrainingReply`: The reply to the
            participant's request. The reply contains the global model weights.
            On failure an empty reply is returned and the call's status is set
            with the error text as details: PERMISSION_DENIED for an
            ``UnknownParticipantError``, FAILED_PRECONDITION for an
            ``InvalidRequestError``.
        """
        try:
            return self.coordinator.on_message(request, context.peer())
        except UnknownParticipantError as error:
            context.set_details(str(error))
            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
            return coordinator_pb2.StartTrainingReply()
        except InvalidRequestError as error:
            context.set_details(str(error))
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            return coordinator_pb2.StartTrainingReply()
Beispiel #5
0
    def EndTraining(
            self, request: coordinator_pb2.EndTrainingRequest,
            context: grpc.ServicerContext) -> coordinator_pb2.EndTrainingReply:
        """Receive a participant's updated weights at the end of its round.

        After finishing local training, the participant calls this method to
        submit its updated weights (and any metrics) to the
        :class:`xain_fl.coordinator.coordinator.Coordinator`.

        Args:
            request (:class:`~.coordinator_pb2.EndTrainingRequest`): Holds the
                weights produced by training plus any metrics useful to the
                :class:`xain_fl.coordinator.coordinator.Coordinator`.
            context (:class:`~grpc.ServicerContext`): The gRPC context of this
                call.

        Returns:
            :class:`~.coordinator_pb2.EndTrainingReply`: An acknowledgment that
            the :class:`xain_fl.coordinator.coordinator.Coordinator` received
            the weights. On failure an empty reply is returned and the status
            is set with the error text as details: ALREADY_EXISTS for a
            ``DuplicatedUpdateError`` (checked first), PERMISSION_DENIED for an
            ``UnknownParticipantError``.
        """
        try:
            return self.coordinator.on_message(request, context.peer())
        except (DuplicatedUpdateError, UnknownParticipantError) as failure:
            # DuplicatedUpdateError takes precedence, as in a separate
            # except-clause listed first.
            if isinstance(failure, DuplicatedUpdateError):
                status = grpc.StatusCode.ALREADY_EXISTS
            else:
                status = grpc.StatusCode.PERMISSION_DENIED
            context.set_details(str(failure))
            context.set_code(status)
            return coordinator_pb2.EndTrainingReply()
Beispiel #6
0
    async def Infer(self, request: InferenceRequest,
                    context: grpc.ServicerContext) -> InferenceResponse:
        """
        Perform inference on the submitted audio data.

        The samples are checked with ``_check_audio`` and then classified by
        the configured inferer. They are archived as a WAV file (rate 16000)
        named ``<uuid>-<label text>.wav`` under the configured inference
        upload path.

        :param request: inference request.
        :param context: RPC context.
        :return: inference response carrying the predicted label's value.
        """
        self._check_audio(request.audio_samples, context)
        peer = context.peer()
        uid = uuid.uuid4()

        logger.info(f"inferring sample from {peer}: UUID {uid}")

        # Convert the protobuf repeated field to an array only once.
        samples = np.array(request.audio_samples, dtype=np.float32)

        start = time.monotonic_ns()
        # Give the inferer its own copy so that any in-place preprocessing it
        # might do cannot change the samples archived below.
        label = await self._inferer.infer(samples.copy())
        elapsed_ns = time.monotonic_ns() - start

        logger.info(
            f"inferred {peer}'s sample in {elapsed_ns}ns as having label {label}"
        )
        wavfile.write(
            self._config.infer_upload_path.joinpath(
                f"{uid}-{label.to_text()}.wav"),
            16000,
            samples,
        )

        return InferenceResponse(label=label.value)
Beispiel #7
0
 def intercept(self, method: Callable, request: Any,
               context: grpc.ServicerContext, method_name: str) -> Any:
     """Log the RPC and delegate to the next handler.

     Override this method to implement a custom interceptor.
     You should call method(request, context) to invoke the
     next handler (either the RPC method implementation, or the
     next interceptor in the list).

     Args:
         method: The next interceptor, or method implementation.
         request: The RPC request, as a protobuf message.
         context: The ServicerContext passed by gRPC to the service.
         method_name: A string of the form
             "/protobuf.package.Service/Method"

     Returns:
         This should generally return the result of
         method(request, context), which is typically the RPC
         method response, as a protobuf message. The interceptor
         is free to modify this in some way, however.

     Raises:
         Any exception from ``method`` is printed and re-raised. A
         ``GrpcException`` is first passed to ``context.abort`` with its
         status code and details.
     """
     # Peer strings look like "ipv4:1.2.3.4:5678" or "ipv6:[::1]:5678":
     # strip the scheme prefix and the trailing port, so IPv6 addresses are
     # not cut at their first colon. Doing this before the try block keeps
     # ``ip`` bound in both handlers and stops logging from failing the RPC.
     _, _, address = context.peer().partition(':')
     ip = address.rsplit(':', 1)[0]
     try:
         print(
             f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)}'
         )
         return method(request, context)
     except GrpcException as e:
         print(
             f'RPC: {method_name} -- {ip} -- {datetime.now().strftime(DATE_FORMAT)} -- {request}'
         )
         # context.abort() raises to end the RPC; the bare raise is a fallback.
         context.abort(e.status_code, e.details)
         raise
     except Exception as e:
         template = "An exception of type {0} occurred. Arguments:\n{1!r}"
         message = template.format(type(e).__name__, e.args)
         print(message)
         raise
Beispiel #8
0
 async def rpc_ping(self, peer_info: dht_pb2.NodeInfo,
                    context: grpc.ServicerContext):
     """Some node wants us to add it to our routing table.

     If ``peer_info`` carries both a node id and an RPC port, the sender's
     endpoint is built from the address gRPC reports for this call, with its
     port replaced by ``peer_info.rpc_port``. A routing-table update is then
     scheduled in the background.

     :param peer_info: NodeInfo describing the calling node.
     :param context: gRPC servicer context for this call.
     :returns: this node's own NodeInfo.
     """
     if peer_info.node_id and peer_info.rpc_port:
         sender_id = DHTID.from_bytes(peer_info.node_id)
         rpc_endpoint = replace_port(context.peer(),
                                     new_port=peer_info.rpc_port)
         update_task = asyncio.create_task(
             self.update_routing_table(sender_id, rpc_endpoint))
         # The event loop keeps only weak references to tasks, so hold a
         # strong one until the update finishes; otherwise it may be
         # garbage-collected before it completes.
         pending = getattr(self, "_pending_routing_updates", None)
         if pending is None:
             pending = set()
             self._pending_routing_updates = pending
         pending.add(update_task)
         update_task.add_done_callback(pending.discard)
     return self.node_info
 def wrapped_set_logging_context(self, request, context: grpc.ServicerContext):
     """Install per-request logging context, then call the wrapped handler.

     The logging context starts from whatever the extractor registered for
     the request's type returns. Two entries are added: "meta", a
     key -> value dict built from the call's invocation metadata (a later
     duplicate key overwrites an earlier one), and "peer", the caller's
     address.
     """
     extractor = request_log_context_extractors[type(request)]
     log_context = extractor(request)
     log_context["meta"] = {
         item.key: item.value for item in context.invocation_metadata()
     }
     log_context["peer"] = context.peer()
     slogging.set_context(log_context)
     self._log.info("new %s", type(request).__name__)
     return func(self, request, context)
Beispiel #10
0
    def Connect(self, request: mw_protocols.ClientConnectRequest, context: grpc.ServicerContext) \
            -> Generator[mw_protocols.Command, None, None]:
        """
        Called by a client when it connects to the Avocado framework.
        Preserves a channel with client for the whole time when client is accessible.
        Note that the connection may break from client side.

        The client is registered with the client model:

        - If registration fails, a single ERROR command carrying the
          response code is sent and the stream ends.
        - A new client receives its assigned id (SET_ID).
        - A client re-establishing a connection instead receives a
          SET_DEPENDENCY command for each dependency it already has.

        After that, the stream emits WAIT commands while no dependency updates
        are queued, and one SET_DEPENDENCY command per queued update. An
        unexpected error is logged, the RPC is cancelled and the stream ends.

        :param request: connection request describing the client.
        :param context: gRPC servicer context; used for the peer address and
            to cancel the call on unexpected errors.
        :return: stream of commands
        """
        def set_dependency_command(dependency, ip):
            # Build one SET_DEPENDENCY command and record the client activity.
            command = mw_protocols.Command()
            command.SET_DEPENDENCY.name = dependency
            command.SET_DEPENDENCY.ip = ip
            logging.info(
                f"Sent address of dependency {dependency} to client "
                f"{request.application}:{request.clientType}:{id_}. Value = {ip}"
            )
            client.last_call = time.perf_counter()
            return command

        # Get IP from request
        peer: str = context.peer()
        client_ip: str = self._parse_ip_from_peer(peer)

        # Prepare the client descriptor:
        client_descriptor = protocols.ClientDescriptor(
            application=request.application,
            type=request.clientType,
            ip=client_ip)
        if request.establishedConnection:
            client_descriptor.hasID = True
            client_descriptor.persistent_id = request.id

        rc, id_, client = self._client_model.assign_client(
            client_descriptor, context)
        if rc != mw_protocols.ClientResponseCode.Value("OK"):
            yield mw_protocols.Command(ERROR=rc)
            return

        if not request.establishedConnection:
            yield mw_protocols.Command(SET_ID=id_)
        else:
            # Snapshot the items: the generator may stay suspended at a yield
            # for a long time, and the dict could presumably be modified
            # meanwhile by the client model.
            for dependency, ip in list(client.dependencies.items()):
                yield set_dependency_command(dependency, ip)

        while True:
            try:
                while len(client.dependency_updates) == 0:
                    time.sleep(self._client_model.wait_signal_frequency)
                    client.last_call = time.perf_counter()
                    yield mw_protocols.Command(WAIT=0)
                dependency, ip = client.dependency_updates.popitem()
                yield set_dependency_command(dependency, ip)
            except Exception:
                # Report the failure, then stop streaming: keeping the loop
                # alive after cancelling would only drive a dead call.
                logging.exception(
                    f"Unexpected error while streaming commands to client "
                    f"{request.application}:{request.clientType}:{id_}; cancelling"
                )
                context.cancel()
                return
Beispiel #11
0
    def dispatch(self, request_object: Message,
                 context_object: grpc.ServicerContext, func: Callable):
        """Run ``func(request_object)`` with the servicer context available.

        ``context_object`` is stored in the ``context`` context variable
        inside a fresh copy of the current contextvars context, and ``func``
        runs in that copy. The handler therefore sees this call's servicer
        context, and the caller's own context is left untouched.

        Error handling:

        - An ``exceptions.GRPCError`` is translated into its status code, plus
          details when present, on ``context_object``.
        - A ``NotImplementedError`` sets ``UNIMPLEMENTED``.
        - In both cases ``None`` is returned.
        - Any other exception is left to the bound logger's ``catch()``
          context manager.
        """
        ctx = contextvars.copy_context()
        # Set the variable *inside* the copy. Setting it in the current
        # context after copying would leave ``func`` seeing no value, or a
        # stale one from an earlier request handled on this thread.
        ctx.run(context.set, context_object)

        bound_logger = self.logger.bind(peer=context_object.peer())

        with bound_logger.catch():
            try:
                return ctx.run(func, request_object)
            except exceptions.GRPCError as e:
                context_object.set_code(e.code)
                if e.details:
                    context_object.set_details(e.details)
            except NotImplementedError:
                context_object.set_code(grpc.StatusCode.UNIMPLEMENTED)
Beispiel #12
0
    async def Train(self, request: TrainingRequest,
                    context: grpc.ServicerContext) -> Empty:
        """
        Re-train the model on the submitted audio data.

        The samples are checked with ``_check_audio`` and stored as a 16-bit
        PCM WAV file (rate 16000) named ``<uuid>.wav``. The file goes into a
        per-label subdirectory of the configured training upload path, which
        is created if needed. The persisted count of untrained samples is
        then incremented. Once it reaches ``samples_before_train``, a
        train-and-swap is requested; the request is skipped if one is
        already running.

        :param request: training request.
        :param context: RPC context.
        :return: empty response.
        """
        self._check_audio(request.audio_samples, context)
        peer = context.peer()
        uid = uuid.uuid4()
        label = Labels(request.label)

        logger.info(
            f"saving training sample from {peer}: UUID {uid}, label {label}")
        samples = np.array(request.audio_samples, dtype=np.float32)

        destination_directory = self._config.training_upload_path.joinpath(
            label.to_text())
        destination_directory.mkdir(mode=0o770, parents=True, exist_ok=True)

        # Scale to 16-bit PCM. Clip first: casting out-of-range floats to
        # int16 is platform-dependent in NumPy and can wrap around. If
        # _check_audio already enforces [-1, 1], the clip is a no-op.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)
        wavfile.write(
            destination_directory.joinpath(f"{uid}.wav"),
            16000,
            pcm,
        )

        untrained_samples = self._persistent_state.increment_untrained_samples(
            1)
        logger.info(
            f"server has collected {untrained_samples} untrained samples")

        if untrained_samples >= self._config.samples_before_train:
            logger.info(
                "collected a sufficient amount of uploaded training samples: starting training"
            )
            if self._train_and_swap():
                logger.info("started training and swapping task")
            else:
                logger.info("training and swapping ongoing, skipping")

        return Empty()
Beispiel #13
0
    def Rendezvous(
            self, request: coordinator_pb2.RendezvousRequest,
            context: grpc.ServicerContext) -> coordinator_pb2.RendezvousReply:
        """First contact between a participant and the coordinator.

        The request is forwarded, together with the caller's peer address, to
        the coordinator. The coordinator either adds the participant to its
        list or, when it already has all the participants it needs, asks the
        participant to try again later.

        Args:
            request (:class:`~.coordinator_pb2.RendezvousRequest`): The
                participant's request.
            context (:class:`~grpc.ServicerContext`): The gRPC context of this
                call.

        Returns:
            :class:`~.coordinator_pb2.RendezvousReply`: An enum reply, either:

                ACCEPT: If the :class:`xain_fl.coordinator.coordinator.Coordinator`
                    does not have enough participants.
                LATER: If the :class:`xain_fl.coordinator.coordinator.Coordinator`
                    already has enough participants.
        """
        caller = context.peer()
        return self.coordinator.on_message(request, caller)
Beispiel #14
0
    def SomeReply(self, request, context: grpc.ServicerContext):
        """Print the caller's peer address, wait five seconds, then reply "bar".

        The request content is ignored. ``time.sleep(5)`` blocks the calling
        thread. NOTE(review): the delay is presumably intentional (e.g. to
        exercise client deadlines); confirm before reusing this handler.
        """
        caller = context.peer()
        print(caller, flush=True)
        time.sleep(5)
        reply = common_pb2.SomeMessage(msg='bar')
        return reply