示例#1
0
    def run(self):
        """Serve DecentralizedAverager forever. This function will not return until the averager is shut down.

        Switches to the uvloop event loop, starts a gRPC server for the averaging service on
        ``self.listen_on``, creates the matchmaking state, sets ``self.ready`` and then dispatches
        ``(method_name, args, kwargs)`` commands received from ``self._pipe`` as background tasks
        on this loop. The dispatch loop never breaks on its own; ``run`` only exits if an exception
        escapes it (for example, if ``self._pipe.recv`` raises).
        """
        loop = switch_to_uvloop()
        # self._pipe.recv is a blocking call, so it is awaited through a worker pool
        # to keep the event loop free to serve gRPC requests meanwhile
        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)

        async def _run():
            grpc.aio.init_grpc_aio()
            server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
            averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
            found_port = server.add_insecure_port(self.listen_on)
            assert found_port != 0, f"Failed to listen to {self.listen_on}"
            # NOTE(review): presumably a shared value read by the process that spawned this one - confirm
            self._port.value = found_port
            self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
                                            return_deltas=True)  # note: we need deltas to make allreduce lock-free
            # initialize asyncio synchronization primitives in this event loop
            self._pending_group_assembled = asyncio.Event()
            self._pending_group_assembled.set()
            await server.start()
            self.ready.set()

            # The event loop keeps only weak references to tasks, so a fire-and-forget task can be
            # garbage-collected before it finishes. Hold a strong reference until each one is done.
            background_tasks = set()
            while True:
                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                background_tasks.add(task)
                task.add_done_callback(background_tasks.discard)

        try:
            loop.run_until_complete(_run())
        finally:
            # release the worker pool; wait=False so we do not block on a worker still inside self._pipe.recv
            pipe_awaiter.shutdown(wait=False)
示例#2
0
    def run(self):
        """Serve ConnectionHandler requests over gRPC until the server terminates.

        Restricts torch to a single thread, switches to the uvloop event loop, then starts a gRPC
        server bound to ``self.listen_on``, sets ``self.ready`` and waits for the server to terminate.
        A KeyboardInterrupt raised while serving is logged at debug level and swallowed, so this
        method then returns normally.
        """
        torch.set_num_threads(1)
        event_loop = switch_to_uvloop()

        async def _serve():
            grpc.aio.init_grpc_aio()
            logger.debug(f'Starting, pid {os.getpid()}')

            # -1 removes gRPC's message size limits; so_reuseport presumably lets several
            # handler processes bind the same port (NOTE(review): confirm intent)
            extra_options = (
                ('grpc.so_reuseport', 1),
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1),
            )
            server = grpc.aio.server(options=GRPC_KEEPALIVE_OPTIONS + extra_options)
            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)

            bound_port = server.add_insecure_port(self.listen_on)
            assert bound_port != 0, f"Failed to listen to {self.listen_on}"

            await server.start()
            self.ready.set()
            await server.wait_for_termination()
            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

        try:
            event_loop.run_until_complete(_serve())
        except KeyboardInterrupt:
            logger.debug('Caught KeyboardInterrupt, shutting down')
示例#3
0
    def run(self):
        """Serve DecentralizedAverager forever. This function will not return until the averager is shut down.

        Switches to the uvloop event loop. If ``self.listen`` is true, starts a gRPC server for the
        averaging service on ``self.listen_on`` and schedules
        ``self._declare_for_download_periodically()``; otherwise logs that the averager runs in
        client mode and starts no server. It then creates the matchmaking state, sets
        ``self.ready`` and dispatches ``(method_name, args, kwargs)`` commands received from
        ``self._pipe`` as background tasks. The dispatch loop never breaks on its own; ``run``
        only exits if an exception escapes it (for example, if ``self._pipe.recv`` raises).
        """
        loop = switch_to_uvloop()
        # self._pipe.recv is a blocking call, so a single worker thread awaits it
        # off the event loop
        pipe_awaiter = ThreadPoolExecutor(max_workers=1)

        async def _run():
            grpc.aio.init_grpc_aio()

            # The event loop keeps only weak references to tasks, so a fire-and-forget
            # task can be garbage-collected before it finishes. Hold a strong
            # reference until each one is done.
            background_tasks = set()

            def _spawn(coro):
                """Schedule coro as a task and keep it alive until it completes."""
                task = asyncio.create_task(coro)
                background_tasks.add(task)
                task.add_done_callback(background_tasks.discard)
                return task

            if self.listen:
                server = grpc.aio.server(**self.kwargs,
                                         options=GRPC_KEEPALIVE_OPTIONS)
                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(
                    self, server)
                found_port = server.add_insecure_port(self.listen_on)
                assert found_port != 0, f"Failed to listen to {self.listen_on}"
                self._port.value = found_port
                await server.start()
            else:
                logger.info(
                    "The averager running in an experimental client mode, please report any bugs."
                )

            self._matchmaking = Matchmaking(self.endpoint,
                                            self.schema_hash,
                                            self.dht,
                                            **self.matchmaking_kwargs,
                                            client_mode=not self.listen)
            if self.listen:
                # long-lived task: without a strong reference it could be collected mid-run
                _spawn(self._declare_for_download_periodically())

            # initialize asyncio synchronization primitives in this event loop
            self._pending_group_assembled = asyncio.Event()
            self._pending_group_assembled.set()
            self.ready.set()

            while True:
                method, args, kwargs = await loop.run_in_executor(
                    pipe_awaiter, self._pipe.recv)
                _spawn(getattr(self, method)(*args, **kwargs))

        try:
            loop.run_until_complete(_run())
        finally:
            # release the worker thread; wait=False so we do not block on a
            # worker still inside self._pipe.recv
            pipe_awaiter.shutdown(wait=False)