Ejemplo n.º 1
0
def client_loop(
    assembler: tube.ChannelAssembler,
    start_time: float,
    context: tube.Context,
    execution_params: ExecutionParams,
) -> None:
    """Start the assembler, then print usage statistics once a minute.

    The loop keeps running until ``start_time + execution_params.max_time``
    has passed; when ``max_time`` is ``None`` it never terminates on its own.

    Args:
        assembler: Object whose ``start()`` method is called once up front.
        start_time: Reference timestamp, compared against ``time.time()``.
        context: Object providing ``get_stats_str()`` for the periodic report.
        execution_params: Supplies ``max_time``, the time budget counted from
            ``start_time`` (``None`` means unlimited).
    """
    assembler.start()
    time_budget = execution_params.max_time
    deadline = None if time_budget is None else start_time + time_budget
    while True:
        # The deadline is only checked between reports, so the loop can
        # overrun it by up to one sleep interval.
        if deadline is not None and not time.time() < deadline:
            break
        time.sleep(60)
        print("Resource usage:")
        usage_report = utils.get_res_usage_str()
        print(usage_report)
        print("Context stats:")
        context_report = context.get_stats_str()
        print(context_report)
Ejemplo n.º 2
0
def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    """Run ``optim_params.epoch_len`` optimisation steps and report throughput.

    Each step samples a batch from ``assembler``, moves it to ``train_device``,
    computes the loss, back-propagates, clips gradients and steps the
    optimizer. Every ``sync_period`` global steps the current model state dict
    is passed to ``assembler.update_model``. At the end, buffer add/sample
    rates and the time spent syncing are printed, and ``stat`` is summarised
    and reset.

    The module-level ``_train_epoch_waiting_time`` is slept at the start of the
    epoch and updated at the end: it grows when the buffer was sampled more
    than 8x as often as it was filled during this epoch, and is reset to zero
    otherwise.

    Args:
        train_device: Device the sampled batches are moved to.
        model: Model whose ``loss`` method, parameters and state dict are used.
        ddpmodel: Optional wrapped model; when not ``None`` it is wrapped in
            ``DDPWrapperForModel`` and passed to ``model.loss`` in place of
            ``model``.
        model_path: Unused here; kept for interface compatibility.
        optim: Optimizer stepped after every backward pass.
        assembler: Source of training batches and receiver of model updates.
        stat: Counter collection fed with per-step loss and gradient norm.
        epoch: Index of the current epoch, used for the global step count and
            for ``stat.summary``.
        optim_params: Supplies ``epoch_len``, ``batchsize`` and ``grad_clip``.
        sync_period: Number of global steps between two model syncs.
    """
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.0
    num_sync = 0
    t = time.time()
    # Back-off computed by previous epochs; note that it is included in the
    # elapsed time measured below.
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss = model.loss(lossmodel, batch["s"], batch["v"], batch["pi"],
                          batch["pi_mask"], stat)
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(),
                                             optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        # Sync on the global step count so the period carries across epochs.
        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)

    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:
        # Sampling outpaced buffer additions by more than 8x: lengthen the
        # sleep taken at the start of the next epoch.
        _train_epoch_waiting_time += time_elapsed
    else:
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    # ".2f" gives two decimals; the previous "2f" spec only set a minimum
    # width of 2 and printed six decimals.
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs ({int(100 * sync_s / time_elapsed)}% of train time)"
    )

    stat.summary(epoch)
    stat.reset()
Ejemplo n.º 3
0
def warm_up_replay_buffer(
    assembler: tube.ChannelAssembler, replay_warmup: int, replay_buffer: Optional[bytes]
) -> None:
    """Start the assembler and block until its buffer reaches ``replay_warmup``.

    While waiting, a progress line (percentage, size, elapsed time and
    instantaneous fill rate) is rewritten in place every time the buffer size
    changes; the size is polled every two seconds. Once warmed up, the final
    size and the average fill rate are printed.

    The average rate is measured from the first poll at which the buffer
    exceeds 10000 entries. If that never happens during the wait (small
    ``replay_warmup``, or a buffer that is already large enough), it is
    measured over the whole warm-up instead.

    Args:
        assembler: Object providing ``start()``, ``buffer_size()`` and a
            writable ``buffer`` attribute.
        replay_warmup: Buffer size to wait for.
        replay_buffer: Optional previously saved buffer; when not ``None`` it
            is assigned to ``assembler.buffer`` before starting.
    """
    if replay_buffer is not None:
        print("loading replay buffer...")
        assembler.buffer = replay_buffer
    assembler.start()

    t = t_init = time.time()
    size_init = assembler.buffer_size()
    prev_buffer_size = -1
    t0 = None  # time at which the average-speed measurement started
    size0 = 0  # buffer size at that time
    while True:
        # Read the size once per iteration so the exit test and the progress
        # report are based on the same value.
        buffer_size = assembler.buffer_size()
        if buffer_size >= replay_warmup:
            break
        if buffer_size != prev_buffer_size:  # avoid flooding stdout
            if buffer_size > 10000 and t0 is None:
                size0 = buffer_size
                t0 = time.time()
            now = time.time()
            elapsed_since_report = now - t
            added = buffer_size - max(prev_buffer_size, 0)
            # Guard against a zero interval on coarse clocks.
            if elapsed_since_report > 0:
                frame_rate = int(added / elapsed_since_report)
            else:
                frame_rate = 0
            prev_buffer_size = buffer_size
            t = now
            duration = t - t_init
            print(
                f"warming-up replay buffer: {(buffer_size * 100) // replay_warmup}% "
                f"({buffer_size}/{replay_warmup}) in {duration:.2f}s "
                f"- speed: {frame_rate} frames/s",
                end="\r",
                flush=True,
            )
        time.sleep(2)

    final_size = assembler.buffer_size()
    print(
        f"replay buffer warmed up: 100% "
        f"({final_size}/{replay_warmup})"
        "                                                                          "
    )
    if t0 is None:
        # The 10000-entry threshold was never crossed while waiting: fall back
        # to the whole warm-up window rather than dividing by a sentinel time.
        t0 = t_init
        size0 = size_init
    avg_window = time.time() - t0
    avg_speed = (final_size - size0) / avg_window if avg_window > 0 else 0.0
    print("avg speed: %.2f frames/s" % avg_speed)
Ejemplo n.º 4
0
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    model_path: Path,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    assembler: tube.ChannelAssembler,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    """Train for up to ``optim_params.num_epoch`` epochs within the time budget.

    Training stops when either ``num_epoch`` epochs have been run since the
    starting ``epoch`` or ``start_time + execution_params.max_time`` has passed
    (no time limit when ``max_time`` is ``None``). Every
    ``execution_params.saving_period`` epochs, before the epoch is trained, the
    model state is registered with ``assembler.add_tournament_model`` and a
    checkpoint is written. After each epoch, resource usage, context statistics
    and the training reward summary are printed. A final checkpoint is always
    written when the loop ends.

    Args:
        command_history: Forwarded to ``utils.save_checkpoint``.
        start_time: Reference timestamp for the time budget.
        train_device: Device forwarded to ``_train_epoch``.
        model: Model being trained and checkpointed.
        model_path: Forwarded to ``_train_epoch``.
        ddpmodel: Forwarded to ``_train_epoch``.
        optim: Optimizer being stepped and checkpointed.
        context: Provides ``get_stats_str()`` for the per-epoch report.
        assembler: Forwarded to ``_train_epoch`` and to the checkpoint.
        get_train_reward: Callable whose result is wrapped in ``utils.Result``
            and logged after each epoch.
        game_params: Forwarded to ``utils.save_checkpoint``.
        model_params: Forwarded to ``utils.save_checkpoint``.
        optim_params: Supplies ``num_epoch``; also forwarded to
            ``_train_epoch`` and the checkpoint.
        simulation_params: Supplies ``sync_period``; also checkpointed.
        execution_params: Supplies ``checkpoint_dir``, ``max_time`` and
            ``saving_period``; also checkpointed.
        epoch: Epoch counter to resume from.
    """

    def _checkpoint(at_epoch: int) -> None:
        # Single place for the checkpoint call used both periodically and at
        # the end of training.
        utils.save_checkpoint(
            command_history=command_history,
            epoch=at_epoch,
            model=model,
            optim=optim,
            assembler=assembler,
            game_params=game_params,
            model_params=model_params,
            optim_params=optim_params,
            simulation_params=simulation_params,
            execution_params=execution_params,
        )

    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    time_budget = execution_params.max_time
    first_epoch = epoch
    while (
        time_budget is None or time.time() < start_time + time_budget
    ) and epoch - first_epoch < optim_params.num_epoch:
        epoch += 1
        epochs_done = epoch - first_epoch
        if epochs_done % execution_params.saving_period == 0:
            assembler.add_tournament_model("e%d" % (epoch), model.state_dict())
            _checkpoint(epoch)
        _train_epoch(
            train_device=train_device,
            model=model,
            ddpmodel=ddpmodel,
            model_path=model_path,
            optim=optim,
            assembler=assembler,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # Per-epoch report: resource usage, context statistics, then reward.
        print("Resource usage:")
        usage_report = utils.get_res_usage_str()
        print(usage_report)
        print("Context stats:")
        context_report = context.get_stats_str()
        print(context_report)
        reward_log = utils.Result(get_train_reward()).log()
        print(">>>train: epoch: %d, %s" % (epoch, reward_log), flush=True)
    # Always persist the last state, whichever stop condition was hit.
    _checkpoint(epoch)