Esempio n. 1
0
def _play_game_against_neural_mcts(
    devices: List[torch.device],
    models: List[torch.jit.ScriptModule],
    context: tube.Context,
    actor_channel: tube.DataChannel,
) -> None:
    """Answer evaluation requests from ``actor_channel`` until ``context`` ends.

    Starts ``context``, then repeatedly pulls a batch from the channel,
    splits its ``"s"`` tensor along dim 0 into at most ``len(devices)``
    chunks, runs ``_forward_pass_on_device`` on each (device, model, chunk)
    triple in parallel threads, and replies on the same channel with the
    concatenated ``"v"`` and ``"pi"`` outputs.

    Args:
        devices: devices each batch is spread over.
        models: models paired positionally with ``devices``.
        context: started here; its ``terminated()`` flag ends the loop.
        actor_channel: channel whose name keys both the input batch and
            the reply.
    """
    nb_devices = len(devices)
    context.start()
    dcm = DataChannelManager([actor_channel])
    try:
        # A single pool serves every batch; building one per iteration only
        # added thread start-up/tear-down cost to each request.
        with ThreadPoolExecutor(max_workers=nb_devices) as executor:
            while not context.terminated():
                batch = dcm.get_input(max_timeout_s=1)
                if len(batch) == 0:
                    # nothing arrived within the timeout: poll again
                    continue
                assert len(batch) == 1

                # split in as many parts as there are devices
                batches_s = torch.chunk(batch[actor_channel.name]["s"],
                                        nb_devices,
                                        dim=0)
                # map() pairs the three sequences positionally (stopping at
                # the shortest) and yields results in submission order, so
                # the concatenation below keeps the original row order.
                results = list(
                    executor.map(_forward_pass_on_device, devices, models,
                                 batches_s))
                reply_eval = {
                    "v": torch.cat([result["v"] for result in results],
                                   dim=0),
                    "pi": torch.cat([result["pi"] for result in results],
                                    dim=0),
                }
                dcm.set_reply(actor_channel.name, reply_eval)
    finally:
        # Terminate the channel manager even when a forward pass raised.
        # NOTE(review): presumably this unblocks peers waiting on the
        # channel - confirm against DataChannelManager.terminate().
        dcm.terminate()
Esempio n. 2
0
def client_loop(
    model_manager: polygames.ModelManager,
    start_time: float,
    context: tube.Context,
    execution_params: ExecutionParams,
) -> None:
    """Start the model manager, then print usage stats once a minute.

    Reporting stops once ``execution_params.max_time`` seconds have elapsed
    since ``start_time``. When ``max_time`` is ``None`` the loop never ends
    on its own.
    """
    model_manager.start()
    time_budget = execution_params.max_time
    while True:
        # Stop as soon as a finite time budget is exhausted.
        if time_budget is not None and not (
                time.time() < start_time + time_budget):
            return
        time.sleep(60)
        reports = (
            ("Resource usage:", utils.get_res_usage_str),
            ("Context stats:", context.get_stats_str),
        )
        for header, fetch in reports:
            print(header)
            print(fetch())
Esempio n. 3
0
def client_loop(
    assembler: tube.ChannelAssembler,
    start_time: float,
    context: tube.Context,
    execution_params: ExecutionParams,
) -> None:
    """Start the assembler and periodically print resource/context stats.

    Every 60 seconds the resource usage string and the context statistics
    are printed. The loop ends once ``execution_params.max_time`` seconds
    have passed since ``start_time``; a ``max_time`` of ``None`` means no
    time limit.
    """
    assembler.start()
    limit = execution_params.max_time

    def expired() -> bool:
        # Only a finite limit can expire.
        if limit is None:
            return False
        return not (time.time() < start_time + limit)

    while not expired():
        time.sleep(60)
        print("Resource usage:")
        usage = utils.get_res_usage_str()
        print(usage)
        print("Context stats:")
        stats = context.get_stats_str()
        print(stats)
Esempio n. 4
0
def _compute_batchsizes(
    model: torch.jit.ScriptModule,
    game_params: GameParams,
    execution_params: ExecutionParams,
) -> dict:
    """Build the per-key shape table handed to ``_train_epoch``."""
    info = zutils.get_game_info(game_params)
    c, h, w = info["feature_size"][:3]
    rc, rh, rw = info["raw_feature_size"][:3]
    c_prime, h_prime, w_prime = info["action_size"][:3]

    predicts = (2 if game_params.predict_end_state else
                0) + game_params.predict_n_states
    has_rnn_state = getattr(model, "rnn_state_shape", None) is not None

    batchsizes = {
        "s": [c, h, w],
        "v": [3 if getattr(model, "logit_value", False) else 1],
        "pred_v": [1],
        "pi": [c_prime, h_prime, w_prime],
        "pi_mask": [c_prime, h_prime, w_prime],
    }

    if game_params.player == "forward":
        batchsizes["action_pi"] = [c_prime, h_prime, w_prime]

    if predicts > 0:
        batchsizes["predict_pi"] = [rc * predicts, rh, rw]
        batchsizes["predict_pi_mask"] = [rc * predicts, rh, rw]

    if has_rnn_state:
        batchsizes["rnn_state_mask"] = [1]

    if execution_params.rnn_seqlen > 0:
        # Every entry registered so far gets a leading sequence dimension.
        batchsizes = {
            key: [execution_params.rnn_seqlen, *shape]
            for key, shape in batchsizes.items()
        }

    if has_rnn_state:
        # Added after the prefixing above, so the initial state keeps the
        # model's own shape without the sequence dimension.
        batchsizes["rnn_initial_state"] = model.rnn_state_shape

    return batchsizes


def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    model_manager: polygames.ModelManager,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    """Run training epochs with periodic checkpoints and a final one.

    The loop stops after ``optim_params.num_epoch`` epochs, or once
    ``execution_params.max_time`` seconds have elapsed since ``start_time``
    (``None`` disables the time limit). After every epoch, resource usage,
    context stats and the training reward summary are printed.

    Args:
        command_history: forwarded to ``utils.save_checkpoint``.
        start_time: reference timestamp for the ``max_time`` budget.
        model: model being trained; its ``state_dict()`` is registered as a
            tournament model at every periodic checkpoint.
        device: forwarded to ``_train_epoch``.
        ddpmodel: forwarded to ``_train_epoch``; when truthy, the rank is
            read from ``torch.distributed`` and only rank 0 writes the
            periodic checkpoints.
        optim: optimizer, forwarded to ``_train_epoch`` and checkpointing.
        context: source of the printed context stats.
        model_manager: receives tournament models; forwarded to
            ``_train_epoch``.
        get_train_reward: called once per epoch; its result is wrapped in
            ``utils.Result`` for logging.
        game_params, model_params, optim_params, simulation_params,
        execution_params: configuration objects stored in each checkpoint;
            several of their fields also drive this loop.
        epoch: epoch number to resume from.
    """
    batchsizes = _compute_batchsizes(model, game_params, execution_params)

    rank = torch.distributed.get_rank() if ddpmodel else 0

    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    savefuture = None

    # The context manager shuts the single-worker pool down on every exit
    # path; previously the executor was never shut down.
    with ThreadPoolExecutor(max_workers=1) as executor:
        while max_time is None or time.time() < start_time + max_time:
            if epoch - init_epoch >= optim_params.num_epoch:
                break
            epoch += 1
            if rank == 0 and epoch % execution_params.saving_period == 0:
                model_manager.add_tournament_model("e%d" % (epoch),
                                                   model.state_dict())
                savestart = time.time()
                # Let the previous save finish before scheduling another.
                if savefuture is not None:
                    savefuture.result()
                savefuture = utils.save_checkpoint(
                    command_history=command_history,
                    epoch=epoch,
                    model=model,
                    optim=optim,
                    game_params=game_params,
                    model_params=model_params,
                    optim_params=optim_params,
                    simulation_params=simulation_params,
                    execution_params=execution_params,
                    executor=executor)
                # NOTE(review): with an executor the save presumably runs in
                # the background, so this mostly times the wait on the
                # previous save plus scheduling - confirm.
                print("checkpoint saved in %gs" % (time.time() - savestart))
            _train_epoch(
                model=model,
                device=device,
                ddpmodel=ddpmodel,
                batchsizes=batchsizes,
                optim=optim,
                model_manager=model_manager,
                stat=stat,
                epoch=epoch,
                optim_params=optim_params,
                sync_period=simulation_params.sync_period,
            )
            # resource usage stats
            print("Resource usage:")
            print(utils.get_res_usage_str())
            print("Context stats:")
            print(context.get_stats_str())
            # train result
            print(
                ">>>train: epoch: %d, %s" %
                (epoch, utils.Result(get_train_reward()).log()),
                flush=True,
            )
        # Make sure the last periodic save has completed (and surface any
        # error it raised) before writing the final checkpoint.
        if savefuture is not None:
            savefuture.result()
    # checkpoint last state
    # NOTE(review): unlike the periodic saves this is not restricted to
    # rank 0 - confirm that is intended.
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )
Esempio n. 5
0
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    model_path: Path,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    assembler: tube.ChannelAssembler,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    """Train for up to ``optim_params.num_epoch`` epochs within the time budget.

    Each iteration bumps ``epoch``, optionally registers a tournament model
    and saves a checkpoint (every ``execution_params.saving_period`` epochs
    counted from the starting epoch), runs ``_train_epoch`` and prints
    resource usage, context stats and the training reward summary. A final
    checkpoint is always written after the loop.

    The loop also stops once ``execution_params.max_time`` seconds have
    elapsed since ``start_time``; ``None`` means no time limit.
    """
    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    time_budget = execution_params.max_time
    first_epoch = epoch

    # Arguments shared by the periodic and the final checkpoint; only the
    # epoch number differs between calls.
    checkpoint_kwargs = dict(
        command_history=command_history,
        model=model,
        optim=optim,
        assembler=assembler,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )

    while time_budget is None or time.time() < start_time + time_budget:
        epochs_done = epoch - first_epoch
        if epochs_done >= optim_params.num_epoch:
            break
        epoch += 1

        is_saving_epoch = (
            (epoch - first_epoch) % execution_params.saving_period == 0)
        if is_saving_epoch:
            assembler.add_tournament_model("e%d" % (epoch), model.state_dict())
            utils.save_checkpoint(epoch=epoch, **checkpoint_kwargs)

        _train_epoch(
            train_device=train_device,
            model=model,
            ddpmodel=ddpmodel,
            model_path=model_path,
            optim=optim,
            assembler=assembler,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )

        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())

        # train result
        reward_summary = utils.Result(get_train_reward()).log()
        print(">>>train: epoch: %d, %s" % (epoch, reward_summary), flush=True)

    # checkpoint last state
    utils.save_checkpoint(epoch=epoch, **checkpoint_kwargs)
Esempio n. 6
0
def _play_game_against_mcts(context: tube.Context) -> None:
    """Start ``context`` and block until it reports termination.

    The termination flag is polled once per second.
    """
    context.start()
    while True:
        if context.terminated():
            return
        time.sleep(1)