def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.
    num_sync = 0
    t = time.time()
    # Sleep if the previous epoch decided the trainer was outpacing data generation.
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss = model.loss(
            lossmodel, batch["s"], batch["v"], batch["pi"], batch["pi_mask"], stat
        )
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), optim_params.grad_clip)
        optim.step()
        optim.zero_grad()
        # Push the latest weights to the assembler every sync_period optimizer steps.
        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1
        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)
    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:
        # Sampling more than 8x faster than new data is added: throttle the
        # trainer by sleeping longer before the next epoch.
        _train_epoch_waiting_time += time_elapsed
    else:
        # The add rate keeps up with sampling; no throttling needed.
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs "
        f"({int(100 * sync_s / time_elapsed)}% of train time)"
    )
    stat.summary(epoch)
    stat.reset()
def create_optimizer(
    model: torch.jit.ScriptModule,
    optim_params: OptimParams,
    optim_state_dict: Optional[dict] = None,
) -> torch.optim.Optimizer:
    optim = torch.optim.Adam(
        model.parameters(), lr=optim_params.lr, eps=optim_params.eps
    )
    if optim_state_dict is not None and not optim_params.reset_optimizer_state:
        try:
            optim.load_state_dict(optim_state_dict)
        except ValueError:
            print("Optimizer state not compatible... skipping.")
    return optim
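# Illustrative usage (not part of the original file): a minimal sketch of how
# create_optimizer and _train_epoch might be wired together by a training
# driver. The function name _example_training_driver, the num_epochs argument,
# and the keyword values below are assumptions for illustration only.
def _example_training_driver(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel,  # ModelWrapperForDDP, or None when not training with DDP
    model_path: Path,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    optim_params: OptimParams,
    sync_period: int,
    num_epochs: int,
    optim_state_dict: Optional[dict] = None,
) -> None:
    # Build the optimizer, optionally restoring its state from a checkpoint.
    optim = create_optimizer(model, optim_params, optim_state_dict)
    for epoch in range(num_epochs):
        # Each call trains for optim_params.epoch_len batches, syncs weights to
        # the assembler every sync_period steps, and self-throttles when
        # sampling outpaces the buffer add rate.
        _train_epoch(
            train_device=train_device,
            model=model,
            ddpmodel=ddpmodel,
            model_path=model_path,
            optim=optim,
            assembler=assembler,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=sync_period,
        )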