def _play_game_against_neural_mcts(
    devices: List[torch.device],
    models: List[torch.jit.ScriptModule],
    context: tube.Context,
    actor_channel: tube.DataChannel,
) -> None:
    """Serve evaluation requests from ``actor_channel`` until ``context`` terminates.

    For every request batch, the ``"s"`` tensor is split along dim 0 into at
    most ``len(devices)`` chunks. Chunk ``i`` is evaluated by
    ``_forward_pass_on_device(devices[i], models[i], chunk)`` on a worker
    thread. The per-chunk ``"v"`` and ``"pi"`` outputs are concatenated back
    along dim 0, in chunk order, and sent as the reply on ``actor_channel``.

    Args:
        devices: Devices to spread evaluation over; one worker thread each.
        models: Models paired positionally with ``devices``.
        context: Context that is started here and whose termination ends
            the serving loop.
        actor_channel: Channel requests are read from and replies written to.

    Raises:
        ValueError: If ``devices`` is empty (a thread pool needs at least one
            worker).
        Any exception raised by ``_forward_pass_on_device`` is propagated
        after the channel manager has been terminated.
    """
    nb_devices = len(devices)
    context.start()
    dcm = DataChannelManager([actor_channel])
    try:
        # One pool for the whole serving loop; creating it per batch spawned
        # and joined nb_devices threads on every single request.
        with ThreadPoolExecutor(max_workers=nb_devices) as executor:
            while not context.terminated():
                # Bounded wait so the termination check above is re-evaluated
                # even when no request arrives (empty batches are skipped).
                batch = dcm.get_input(max_timeout_s=1)
                if len(batch) == 0:
                    continue
                # Only actor_channel is registered with the manager.
                assert len(batch) == 1
                # Split in as many parts as there are devices. torch.chunk may
                # return fewer chunks for small batches; map() stops at the
                # shortest input, so only the first len(chunks) devices are used.
                chunks = torch.chunk(batch[actor_channel.name]["s"], nb_devices, dim=0)
                # map() yields results in submission order, so concatenating
                # them restores the original row order of the batch.
                results = list(
                    executor.map(_forward_pass_on_device, devices, models, chunks)
                )
                reply_eval = {
                    "v": torch.cat([result["v"] for result in results], dim=0),
                    "pi": torch.cat([result["pi"] for result in results], dim=0),
                }
                dcm.set_reply(actor_channel.name, reply_eval)
    finally:
        # Release the channel manager even if evaluation raised.
        dcm.terminate()
def _play_game_against_mcts(context: tube.Context) -> None:
    """Start ``context`` and block the calling thread until it terminates.

    The function does no work of its own; it just polls
    ``context.terminated()`` once per second after starting the context.
    """
    context.start()
    poll_interval_s = 1
    while True:
        if context.terminated():
            break
        time.sleep(poll_interval_s)