def client_loop(
    assembler: tube.ChannelAssembler,
    start_time: float,
    context: tube.Context,
    execution_params: ExecutionParams,
) -> None:
    """Run the client process: start the assembler and report stats once a minute.

    Loops until ``execution_params.max_time`` seconds (if set) have elapsed
    since ``start_time``; with no ``max_time`` it loops forever.

    Args:
        assembler: channel assembler to start before entering the loop.
        start_time: wall-clock time (``time.time()``) the run began.
        context: tube context whose stats are printed each minute.
        execution_params: provides the optional ``max_time`` budget.
    """
    assembler.start()
    # Precompute the absolute deadline once; None means "run forever".
    deadline = (
        None
        if execution_params.max_time is None
        else start_time + execution_params.max_time
    )
    while deadline is None or time.time() < deadline:
        time.sleep(60)
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    """Run one optimization epoch over batches sampled from the replay buffer.

    Performs ``optim_params.epoch_len`` SGD steps, periodically pushing the
    updated weights to the actors via ``assembler.update_model``, then prints
    buffer add/sample rates and sync overhead and flushes ``stat``.

    Args:
        train_device: device batches are moved to before the forward pass.
        model: scripted model providing ``loss`` and the trained parameters.
        ddpmodel: optional DDP wrapper; when not None the loss is computed
            through ``DDPWrapperForModel(ddpmodel)`` instead of ``model``.
        model_path: checkpoint path (currently unused here; kept for the
            caller's interface).
        optim: optimizer stepped once per batch.
        assembler: replay-buffer / model-distribution channel assembler.
        stat: counters fed with per-batch loss and grad norm.
        epoch: current epoch index (used for the sync-period schedule and
            stat summary).
        optim_params: batch size, epoch length and gradient-clip threshold.
        sync_period: push weights to actors every this many global steps.
    """
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.0
    num_sync = 0
    t = time.time()
    # Throttle: sleep for the backoff accumulated by previous epochs when
    # training was sampling much faster than actors were adding frames.
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss = model.loss(
            lossmodel, batch["s"], batch["v"], batch["pi"], batch["pi_mask"], stat
        )
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(
            model.parameters(), optim_params.grad_clip
        )
        optim.step()
        optim.zero_grad()
        # Push fresh weights to the actors every `sync_period` global steps.
        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1
        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)
    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:
        # Sampling more than 8x faster than frames are added: training is
        # outpacing data generation, so grow the backoff for the next epoch.
        _train_epoch_waiting_time += time_elapsed
    else:
        # Data generation is keeping up; reset the backoff.
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    # NOTE: original used `{sync_s:2f}` (width 2, default precision) — the
    # intended 2-decimal format is `{sync_s:.2f}`.
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs "
        f"({int(100 * sync_s / time_elapsed)}% of train time)"
    )
    stat.summary(epoch)
    stat.reset()
def warm_up_replay_buffer(
    assembler: tube.ChannelAssembler, replay_warmup: int, replay_buffer: Optional[bytes]
) -> None:
    """Start the assembler and block until the buffer holds `replay_warmup` frames.

    Optionally preloads a serialized replay buffer, then polls every two
    seconds, printing a single-line progress indicator (overwritten with
    ``\\r``) plus a final average-speed summary.

    Args:
        assembler: channel assembler whose buffer is being filled.
        replay_warmup: number of frames to wait for.
        replay_buffer: optional serialized buffer to load before starting.
    """
    if replay_buffer is not None:
        print("loading replay buffer...")
        assembler.buffer = replay_buffer
    assembler.start()
    last_size = -1
    last_t = t_init = time.time()
    # Steady-state measurement starts once the buffer passes 10k frames,
    # to exclude the initial ramp-up from the average-speed figure.
    steady_t0 = -1
    steady_size0 = 0
    while assembler.buffer_size() < replay_warmup:
        cur_size = assembler.buffer_size()
        # Only report when the size actually changed, to avoid flooding stdout.
        if cur_size != last_size:
            if cur_size > 10000 and steady_t0 == -1:
                steady_size0 = cur_size
                steady_t0 = time.time()
            last_size = max(last_size, 0)
            frame_rate = int((cur_size - last_size) / (time.time() - last_t))
            last_size = cur_size
            last_t = time.time()
            duration = last_t - t_init
            print(
                f"warming-up replay buffer: {(cur_size * 100) // replay_warmup}% "
                f"({cur_size}/{replay_warmup}) in {duration:.2f}s "
                f"- speed: {frame_rate} frames/s",
                end="\r",
                flush=True,
            )
        time.sleep(2)
    print(
        f"replay buffer warmed up: 100% "
        f"({assembler.buffer_size()}/{replay_warmup})"
        "                                                  "
    )
    print(
        "avg speed: %.2f frames/s"
        % ((assembler.buffer_size() - steady_size0) / (time.time() - steady_t0))
    )
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    model_path: Path,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    assembler: tube.ChannelAssembler,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    """Drive the full training run: epochs, periodic saves, and final checkpoint.

    Runs ``_train_epoch`` until either ``optim_params.num_epoch`` epochs have
    completed or the ``execution_params.max_time`` wall-clock budget (relative
    to ``start_time``) is exhausted. Every ``saving_period`` epochs the current
    model is registered as a tournament model and a checkpoint is written; a
    final checkpoint is always written on exit.

    Args:
        epoch: starting epoch index, e.g. when resuming from a checkpoint.
    """
    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch

    def _checkpoint() -> None:
        # Persist the complete training state so the run can be resumed.
        utils.save_checkpoint(
            command_history=command_history,
            epoch=epoch,
            model=model,
            optim=optim,
            assembler=assembler,
            game_params=game_params,
            model_params=model_params,
            optim_params=optim_params,
            simulation_params=simulation_params,
            execution_params=execution_params,
        )

    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if (epoch - init_epoch) % execution_params.saving_period == 0:
            assembler.add_tournament_model("e%d" % (epoch), model.state_dict())
            _checkpoint()
        _train_epoch(
            train_device=train_device,
            model=model,
            ddpmodel=ddpmodel,
            model_path=model_path,
            optim=optim,
            assembler=assembler,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s"
            % (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    # checkpoint last state
    _checkpoint()