# Standard-library and third-party imports used throughout this module.
import copy
import gzip
import os
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Callable, List, Optional

import torch
import torch.nn as nn

# Project-local names referenced below (import paths were not shown in the
# original source): tube, polygames, utils, zutils, model_loss, plus
# CommandHistory, GameParams, ModelParams, OptimParams, SimulationParams,
# ExecutionParams, DDPWrapperForModel and ModelWrapperForDDP.

# Wait state for the throttling heuristic in _train_epoch below; the zero
# initial value is an assumption (the module preamble was not shown).
_train_epoch_waiting_time = 0.0


def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.0
    num_sync = 0
    t = time.time()
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss = model.loss(lossmodel, batch["s"], batch["v"], batch["pi"], batch["pi_mask"], stat)
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), optim_params.grad_clip)
        optim.step()
        optim.zero_grad()
        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1
        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)
    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:
        # The buffer was sampled more than 8x faster than it was filled:
        # throttle training by sleeping longer before the next epoch.
        _train_epoch_waiting_time += time_elapsed
    else:
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs ({int(100 * sync_s / time_elapsed)}% of train time)"
    )
    stat.summary(epoch)
    stat.reset()
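# A compact restatement of the throttling heuristic above, for clarity only;
# the helper name is illustrative and not part of the original module. Each
# epoch that consumes samples more than 8x faster than they are produced
# lengthens the pre-epoch sleep by that epoch's duration; otherwise the sleep
# resets to zero.
def _next_waiting_time(current_wait: float, delta_add: int, delta_sample: int, epoch_seconds: float) -> float:
    if delta_sample > 8 * delta_add:
        return current_wait + epoch_seconds
    return 0.0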
def save_checkpoint(
    command_history: CommandHistory,
    epoch: int,
    model: torch.jit.ScriptModule,
    optim: torch.optim.Optimizer,
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    executor: Optional[ThreadPoolExecutor] = None,
) -> None:
    checkpoint_dir = execution_params.checkpoint_dir
    save_uncompressed = execution_params.save_uncompressed
    checkpoint_name = f"checkpoint_{epoch}"
    # Copy all tensors to CPU up front so the checkpoint can be written from a
    # background thread while training continues to mutate the live model.
    checkpoint = {
        "command_history": command_history,
        "epoch": epoch,
        "model_state_dict": {
            k: v.cpu().clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
            for k, v in model.state_dict().items()
        },
        "optim_state_dict": {
            k: v.cpu().clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
            for k, v in optim.state_dict().items()
        },
        "game_params": game_params,
        "model_params": model_params,
        "optim_params": optim_params,
        "simulation_params": simulation_params,
        "execution_params": execution_params,
    }

    def saveit():
        nonlocal save_uncompressed
        nonlocal checkpoint
        nonlocal checkpoint_dir
        if save_uncompressed:
            torch.save(checkpoint, checkpoint_dir / f"{checkpoint_name}.pt")
        else:
            # with zipfile.ZipFile(Path(checkpoint_dir) / f"{checkpoint_name}.zip", "w", allowZip64=True) as z:
            #     with z.open(f"{checkpoint_name}.pt", "w", force_zip64=True) as f:
            #         torch.save(checkpoint, f)
            with gzip.open(checkpoint_dir / f"{checkpoint_name}.pt.gz", "wb") as f:
                torch.save(checkpoint, f)

    if executor is not None:
        return executor.submit(saveit)
    saveit()
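# A minimal loading counterpart to the save_checkpoint variants in this module;
# the helper name load_checkpoint is illustrative and not part of the original
# source. It assumes only the on-disk formats produced above: plain torch.save
# for *.pt, and gzip-wrapped torch.save for *.pt.gz.
def load_checkpoint(checkpoint_dir: Path, epoch: int) -> dict:
    uncompressed = checkpoint_dir / f"checkpoint_{epoch}.pt"
    if uncompressed.exists():
        return torch.load(uncompressed, map_location="cpu")
    with gzip.open(checkpoint_dir / f"checkpoint_{epoch}.pt.gz", "rb") as f:
        return torch.load(f, map_location="cpu")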
def save_checkpoint(
    command_history: CommandHistory,
    epoch: int,
    model: torch.jit.ScriptModule,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
) -> None:
    checkpoint_dir = execution_params.checkpoint_dir
    save_uncompressed = execution_params.save_uncompressed
    do_not_save_replay_buffer = execution_params.do_not_save_replay_buffer
    checkpoint_name = f"checkpoint_{epoch}"
    checkpoint = {
        "command_history": command_history,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optim_state_dict": optim.state_dict(),
        "game_params": game_params,
        "model_params": model_params,
        "optim_params": optim_params,
        "simulation_params": simulation_params,
        "execution_params": execution_params,
    }
    if not do_not_save_replay_buffer:
        checkpoint.update({"replay_buffer": assembler.buffer})
    if save_uncompressed:
        torch.save(checkpoint, checkpoint_dir / f"{checkpoint_name}.pt")
    else:
        # with zipfile.ZipFile(Path(checkpoint_dir) / f"{checkpoint_name}.zip", "w", allowZip64=True) as z:
        #     with z.open(f"{checkpoint_name}.pt", "w", force_zip64=True) as f:
        #         torch.save(checkpoint, f)
        with gzip.open(checkpoint_dir / f"temp_{checkpoint_name}.pt.gz", "wb") as f:
            torch.save(checkpoint, f)
        # Write to a temporary name first, then rename: on POSIX the rename is
        # atomic, so a crash mid-write never leaves a truncated checkpoint
        # under the final name.
        os.rename(
            checkpoint_dir / f"temp_{checkpoint_name}.pt.gz",
            checkpoint_dir / f"{checkpoint_name}.pt.gz",
        )
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    model_manager: polygames.ModelManager,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    info = zutils.get_game_info(game_params)
    c, h, w = info["feature_size"][:3]
    rc, rh, rw = info["raw_feature_size"][:3]
    c_prime, h_prime, w_prime = info["action_size"][:3]
    predicts = (2 if game_params.predict_end_state else 0) + game_params.predict_n_states

    batchsizes = {
        "s": [c, h, w],
        "v": [3 if getattr(model, "logit_value", False) else 1],
        "pred_v": [1],
        "pi": [c_prime, h_prime, w_prime],
        "pi_mask": [c_prime, h_prime, w_prime],
    }

    if game_params.player == "forward":
        batchsizes["action_pi"] = [c_prime, h_prime, w_prime]

    if predicts > 0:
        batchsizes["predict_pi"] = [rc * predicts, rh, rw]
        batchsizes["predict_pi_mask"] = [rc * predicts, rh, rw]

    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_state_mask"] = [1]

    if execution_params.rnn_seqlen > 0:
        for k, v in batchsizes.items():
            batchsizes[k] = [execution_params.rnn_seqlen, *v]

    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_initial_state"] = model.rnn_state_shape

    rank = 0
    if ddpmodel:
        rank = torch.distributed.get_rank()

    executor = ThreadPoolExecutor(max_workers=1)
    savefuture = None

    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if rank == 0 and epoch % execution_params.saving_period == 0:
            model_manager.add_tournament_model("e%d" % (epoch), model.state_dict())
            savestart = time.time()
            # At most one checkpoint write is in flight: wait for the previous
            # background save before submitting a new one. The time printed
            # below therefore covers submission plus any such wait, not the
            # write itself.
            if savefuture is not None:
                savefuture.result()
            savefuture = utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
                executor=executor,
            )
            print("checkpoint saved in %gs" % (time.time() - savestart))
        _train_epoch(
            model=model,
            device=device,
            ddpmodel=ddpmodel,
            batchsizes=batchsizes,
            optim=optim,
            model_manager=model_manager,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s" % (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    if savefuture is not None:
        savefuture.result()
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )
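# Module-level state shared across calls to the _train_epoch below. The
# original module preamble was not shown, so these initial values are
# assumptions chosen to match how the function reads them: buffer counters
# start unknown (None), rates start at zero, and _last_train_time == 0 marks
# "not yet trained". _remote_replay_buffer_inited is declared global in
# _train_epoch but unused in the code shown.
_pre_num_add = None
_pre_num_sample = None
_running_add_rate = 0.0
_running_sample_rate = 0.0
_last_train_time = 0
_remote_replay_buffer_inited = False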
def _train_epoch(
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel: ModelWrapperForDDP,
    batchsizes,
    optim: torch.optim.Optimizer,
    model_manager: polygames.ModelManager,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _pre_num_add
    global _pre_num_sample
    global _running_add_rate
    global _running_sample_rate
    global _last_train_time
    global _remote_replay_buffer_inited
    if _pre_num_add is None:
        pre_num_add = model_manager.buffer_num_add()
        pre_num_sample = model_manager.buffer_num_sample()
    else:
        pre_num_add = _pre_num_add
        pre_num_sample = _pre_num_sample
    sync_s = 0.0
    num_sync = 0
    train_start_time = time.time()

    if pre_num_sample > 0:
        print("sample/add ratio ", float(pre_num_sample) / pre_num_add)

    if _last_train_time == 0:
        _last_train_time = time.time()

    batchsize = optim_params.batchsize

    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    lossmodel.train()

    world_size = 0
    rank = 0
    if ddpmodel is not None:
        print("DDP is active")
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        print("World size %d, rank %d. Waiting for all processes" % (world_size, rank))
        torch.distributed.barrier()
        print("Synchronizing model")
        # Broadcast parameters and buffers from rank 0 so all workers start
        # from identical weights.
        for p in ddpmodel.parameters():
            torch.distributed.broadcast(p.data, 0)
        for p in ddpmodel.buffers():
            torch.distributed.broadcast(p.data, 0)
        print("Synchronized, start training")

    has_predict = False
    cpubatch = {}
    for k, v in batchsizes.items():
        sizes = v.copy()
        sizes.insert(0, batchsize)
        cpubatch[k] = torch.empty(sizes)
        if k == "predict_pi":
            has_predict = True

    for eid in range(optim_params.epoch_len):
        # Throttle training whenever the replay buffer is sampled much faster
        # than new data is added. The running rates are refreshed inside the
        # wait loop so it can observe the add rate catching up; without the
        # refresh this loop could never terminate.
        while _running_add_rate * 1.25 < _running_sample_rate:
            print("add rate insufficient, waiting")
            time.sleep(5)
            t = time.time()
            time_elapsed = t - _last_train_time
            _last_train_time = t
            alpha = pow(0.99, time_elapsed)
            post_num_add = model_manager.buffer_num_add()
            post_num_sample = model_manager.buffer_num_sample()
            delta_add = post_num_add - pre_num_add
            delta_sample = post_num_sample - pre_num_sample
            _running_add_rate = _running_add_rate * alpha + (delta_add / time_elapsed) * (1 - alpha)
            _running_sample_rate = _running_sample_rate * alpha + (delta_sample / time_elapsed) * (1 - alpha)
            pre_num_add = post_num_add
            pre_num_sample = post_num_sample
            print("running add rate: %.2f / s" % (_running_add_rate))
            print("running sample rate: %.2f / s" % (_running_sample_rate))
            print("current add rate: %.2f / s" % (delta_add / time_elapsed))
            print("current sample rate: %.2f / s" % (delta_sample / time_elapsed))

        if world_size > 0:
            # Under DDP, rank 0 draws one batch per worker and scatters them;
            # every rank receives its share into the preallocated CPU tensors.
            batchlist = None
            if rank == 0:
                batchlist = {}
                for k in cpubatch.keys():
                    batchlist[k] = []
                for i in range(world_size):
                    for k, v in model_manager.sample(batchsize).items():
                        batchlist[k].append(v)
            for k, v in cpubatch.items():
                torch.distributed.scatter(v, batchlist[k] if rank == 0 else None)
            batch = utils.to_device(cpubatch, device)
        else:
            batch = model_manager.sample(batchsize)
            batch = utils.to_device(batch, device)
        for k, v in batch.items():
            batch[k] = v.detach()
        loss, v_err, pi_err, predict_err = model_loss.mcts_loss(model, lossmodel, batch)
        loss.backward()

        grad_norm = nn.utils.clip_grad_norm_(lossmodel.parameters(), optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        stat["v_err"].feed(v_err.item())
        stat["pi_err"].feed(pi_err.item())
        if has_predict:
            stat["predict_err"].feed(predict_err.item())
        stat["loss"].feed(loss.item())
        stat["grad_norm"].feed(grad_norm)

        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            model_manager.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

    # Final rate refresh and epoch summary.
    t = time.time()
    time_elapsed = t - _last_train_time
    _last_train_time = t
    alpha = pow(0.99, time_elapsed)
    post_num_add = model_manager.buffer_num_add()
    post_num_sample = model_manager.buffer_num_sample()
    delta_add = post_num_add - pre_num_add
    delta_sample = post_num_sample - pre_num_sample
    _running_add_rate = _running_add_rate * alpha + (delta_add / time_elapsed) * (1 - alpha)
    _running_sample_rate = _running_sample_rate * alpha + (delta_sample / time_elapsed) * (1 - alpha)
    pre_num_add = post_num_add
    pre_num_sample = post_num_sample
    total_time_elapsed = time.time() - train_start_time
    print("running add rate: %.2f / s" % (_running_add_rate))
    print("running sample rate: %.2f / s" % (_running_sample_rate))
    print("current add rate: %.2f / s" % (delta_add / time_elapsed))
    print("current sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs ({int(100 * sync_s / total_time_elapsed)}% of train time)"
    )
    _pre_num_add = pre_num_add
    _pre_num_sample = pre_num_sample
    stat.summary(epoch)
    stat.reset()
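# The running rates above are a time-aware exponential moving average: the old
# estimate decays by a factor of 0.99 per second, so after a gap of dt seconds
# its remaining weight is 0.99 ** dt. A standalone sketch of the same update
# (the helper name is illustrative, not part of the original module):
def _update_running_rate(running: float, delta: int, dt: float, decay: float = 0.99) -> float:
    alpha = decay ** dt  # weight kept by the old estimate after dt seconds
    return running * alpha + (delta / dt) * (1 - alpha)

# E.g. with running = 10.0 adds/s and 30 adds over dt = 2 s (15 adds/s
# instantaneous): alpha = 0.99 ** 2 ≈ 0.9801, so the new estimate is
# 10.0 * 0.9801 + 15 * 0.0199 ≈ 10.1 adds/s.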
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    model_path: Path,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    assembler: tube.ChannelAssembler,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if not (epoch - init_epoch) % execution_params.saving_period:
            assembler.add_tournament_model("e%d" % (epoch), model.state_dict())
            utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                assembler=assembler,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
            )
        _train_epoch(
            train_device=train_device,
            model=model,
            ddpmodel=ddpmodel,
            model_path=model_path,
            optim=optim,
            assembler=assembler,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s" % (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        assembler=assembler,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )