def run(self):
    """
    Serve DecentralizedAverager forever. This function will not return until the averager is shut down.

    Binds a gRPC aio server to ``self.listen_on`` (with ``self`` as the DecentralizedAveraging servicer),
    stores the bound port in ``self._port.value``, creates the matchmaking state, starts the server and sets
    ``self.ready``. It then loops forever: every ``(method, args, kwargs)`` message received from ``self._pipe``
    is dispatched as ``getattr(self, method)(*args, **kwargs)`` in its own asyncio task.
    """
    loop = switch_to_uvloop()
    # initialize asyncio synchronization primitives in this event loop
    # executor that runs the blocking self._pipe.recv calls off the event loop thread
    pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)

    async def _run():
        grpc.aio.init_grpc_aio()
        server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
        found_port = server.add_insecure_port(self.listen_on)
        assert found_port != 0, f"Failed to listen to {self.listen_on}"
        self._port.value = found_port
        self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
                                        return_deltas=True)  # note: we need deltas to make allreduce lock-free
        # created inside the running coroutine so that the event belongs to this event loop
        self._pending_group_assembled = asyncio.Event()
        self._pending_group_assembled.set()
        await server.start()
        self.ready.set()

        # The event loop only keeps weak references to tasks; hold strong references until each task
        # finishes so that a pending request handler cannot be garbage-collected mid-execution.
        background_tasks = set()
        while True:
            method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
            task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)

    loop.run_until_complete(_run())
def run(self):
    """
    Run the ConnectionHandler gRPC server until it terminates.

    Restricts torch to a single CPU thread, switches to a uvloop event loop and serves ``self`` as the
    ConnectionHandler servicer on ``self.listen_on``. ``self.ready`` is set once the server has started.
    A KeyboardInterrupt stops serving and is reported with a debug log message instead of propagating.
    """
    torch.set_num_threads(1)
    event_loop = switch_to_uvloop()

    async def _serve():
        grpc.aio.init_grpc_aio()
        logger.debug(f'Starting, pid {os.getpid()}')

        # enable SO_REUSEPORT and lift gRPC's message size limits (-1 means no limit)
        extra_options = (
            ('grpc.so_reuseport', 1),
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        )
        grpc_server = grpc.aio.server(options=GRPC_KEEPALIVE_OPTIONS + extra_options)
        runtime_grpc.add_ConnectionHandlerServicer_to_server(self, grpc_server)

        bound_port = grpc_server.add_insecure_port(self.listen_on)
        assert bound_port != 0, f"Failed to listen to {self.listen_on}"

        await grpc_server.start()
        self.ready.set()
        await grpc_server.wait_for_termination()
        logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

    try:
        event_loop.run_until_complete(_serve())
    except KeyboardInterrupt:
        logger.debug('Caught KeyboardInterrupt, shutting down')
def run(self):
    """
    Serve DecentralizedAverager forever. This function will not return until the averager is shut down.

    If ``self.listen`` is true, binds a gRPC aio server to ``self.listen_on`` (with ``self`` as the
    DecentralizedAveraging servicer) and stores the bound port in ``self._port.value``; otherwise runs in
    client mode without a server. Matchmaking state is created before the server starts accepting requests.
    After ``self.ready`` is set, loops forever: every ``(method, args, kwargs)`` message received from
    ``self._pipe`` is dispatched as ``getattr(self, method)(*args, **kwargs)`` in its own asyncio task.
    """
    loop = switch_to_uvloop()
    # initialize asyncio synchronization primitives in this event loop
    # single worker: self._pipe.recv is awaited one message at a time
    pipe_awaiter = ThreadPoolExecutor(max_workers=1)

    async def _run():
        grpc.aio.init_grpc_aio()

        server = None
        if self.listen:
            server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
            averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
            found_port = server.add_insecure_port(self.listen_on)
            assert found_port != 0, f"Failed to listen to {self.listen_on}"
            self._port.value = found_port
        else:
            logger.info(f"The averager running in an experimental client mode, please report any bugs.")

        self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
                                        client_mode=not self.listen)
        # created inside the running coroutine so that the event belongs to this event loop
        self._pending_group_assembled = asyncio.Event()
        self._pending_group_assembled.set()

        # The event loop only keeps weak references to tasks; hold strong references until each task
        # finishes so that background work cannot be garbage-collected mid-execution.
        background_tasks = set()

        def _spawn(coro):
            """Schedule ``coro`` as a task and keep it referenced until it completes."""
            task = asyncio.create_task(coro)
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)

        if server is not None:
            # NOTE: start serving only after _matchmaking / _pending_group_assembled exist, so that incoming
            # RPCs (whose handlers presumably use them) never observe a half-initialized averager.
            await server.start()
            _spawn(self._declare_for_download_periodically())

        self.ready.set()
        while True:
            method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
            _spawn(getattr(self, method)(*args, **kwargs))

    loop.run_until_complete(_run())