def _get_server_config(self):
    """Fetches the inference server's config and derives batch sizing.

    Blocks, retrying once per second, until the server answers GetConfig.

    Side effects:
        Sets self.stub, self.positions_per_inference and self.batch_size.

    Raises:
        RuntimeError: if the server's board size doesn't match go.N, or if
            positions_per_inference isn't divisible by parallel_inferences.
    """
    while True:
        try:
            grpc_channel = grpc.insecure_channel(FLAGS.server_address)
            self.stub = inference_service_pb2_grpc.InferenceServiceStub(
                grpc_channel)
            server_config = self.stub.GetConfig(
                inference_service_pb2.GetConfigRequest())
            break
        except grpc.RpcError:
            # Server not up yet (or transiently unreachable): retry.
            dbg("Waiting for server")
            time.sleep(1)

    if server_config.board_size != go.N:
        raise RuntimeError("Board size mismatch: server=%d, worker=%d" % (
            server_config.board_size, go.N))

    # Total positions evaluated per inference call across all games.
    self.positions_per_inference = (
        server_config.games_per_inference * server_config.virtual_losses)

    # NOTE(review): message says "parallel_tpus" but the check uses
    # parallel_inferences — confirm the flag name matches.
    if self.positions_per_inference % self.parallel_inferences != 0:
        raise RuntimeError(
            "games_per_inference * virtual_losses must be divisible by "
            "parallel_tpus")
    self.batch_size = self.positions_per_inference // self.parallel_inferences

    dbg("parallel_inferences = %d" % self.parallel_inferences)
    dbg("games_per_inference = %d" % server_config.games_per_inference)
    dbg("virtual_losses = %d" % server_config.virtual_losses)
    dbg("positions_per_inference = %d" % self.positions_per_inference)
    dbg("batch_size = %d" % self.batch_size)
def get_server_config():
    """Connects to the inference server and fetches its configuration.

    Blocks, retrying once per second, until the server responds.

    Returns:
        Server's configuration as a inference_service_pb2.GetConfigResponse
        proto.
    """
    while True:
        try:
            # Fetch the server config, used to set batch size.
            channel = grpc.insecure_channel(FLAGS.address)
            stub = inference_service_pb2_grpc.InferenceServiceStub(channel)
            return stub.GetConfig(inference_service_pb2.GetConfigRequest())
        except grpc.RpcError:
            # Narrowed from a broad `except Exception`: only a gRPC failure
            # means "server not up yet". Retrying on every exception would
            # loop forever on genuine programming errors instead of
            # surfacing them. Matches _get_server_config's handling.
            print("Waiting for server")
            time.sleep(1)