def save_models(self, cnt, state_dict, model_name):
    """Bundle the MNMT / DAMSM / GAN model and optimizer states (plus AMP
    scaler, scheduler step count, and run settings) into one checkpoint file.
    """
    checkpoint = {
        "cnt": cnt,
        "models": {
            "MNMT": state_dict,
            "tgt_DAMSM": (get_state_dict(self.tgt_DAMSM_RNN)
                          if self.tgt_DAMSM_RNN is not None else None),
            "netG": get_state_dict(self.netG),
        },
        "optims": {
            "MNMT": self.MNMT_optimizer.state_dict(),
            "tgt_DAMSM": (self.DAMSM_optimizer.state_dict()
                          if self.DAMSM_optimizer is not None else None),
            "netG": self.netG_optimizer.state_dict(),
        },
        "scaler": self.scaler.state_dict(),
        "steps_cnt": self.scheduler.current_steps,
        "settings": self.opt,
    }
    # one discriminator per stage, keyed by its image resolution (64, 128, ...)
    for stage in range(self.stage_num):
        key = "netD_" + str(64 * 2 ** stage)
        checkpoint["models"][key] = get_state_dict(self.netsD[stage])
        checkpoint["optims"][key] = self.netD_optimizers[stage].state_dict()
    torch.save(checkpoint, f"{self.save_model_dir}/{model_name}")
def save_model(self, cnt, model_name):
    """Write both encoders, their optimizers, the AMP scaler, and the run
    settings to a single checkpoint file under ``self.opt.save_model_dir``.
    """
    checkpoint = dict(
        cnt=cnt,
        image_encoder=get_state_dict(self.image_encoder),
        text_encoder=get_state_dict(self.text_encoder),
        image_optimizer=self.image_optimizer.state_dict(),
        text_optimizer=self.text_optimizer.state_dict(),
        scaler=self.scaler.state_dict(),
        settings=self.opt,
    )
    torch.save(checkpoint, f"{self.opt.save_model_dir}/{model_name}")
def train_by_epoch(self, start_epoch=1):
    """Epoch-based training loop: train one epoch, validate (when a validator
    is configured), and checkpoint epochs past the first third of training.
    """
    self.device = get_device(self.model)
    self.model.train()
    overall_start = time()

    for epoch in range(start_epoch, self.opt.max_epoch + 1):
        self.logger.info(f"\n[ Epoch {epoch} ]")
        epoch_start = time()
        avg_epoch_loss = self._train_epoch()
        minutes = (time() - epoch_start) / 60
        self.logger.info(
            f"word_loss : {avg_epoch_loss:.2f}, time : {minutes:.2f} min"
        )

        # validation returns the state_dict to save; without a validator,
        # snapshot the current model instead
        if self.validator is None:
            state_dict = get_state_dict(self.model)
        else:
            state_dict = self.validation(epoch)

        # skip checkpoints during the first third of training
        if epoch > self.opt.max_epoch / 3:
            self.save_model(epoch, state_dict, f"epoch_{epoch}.pth")

    hours = (time() - overall_start) / 3600
    self.logger.info(
        f"\nbest_epoch : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {hours:.2f} h"
    )
def save_model(self, cnt, model_name):
    """Checkpoint the generator using its averaged (EMA) weights plus every
    stage discriminator, then restore the generator's live training weights.
    """
    # swap the averaged weights into netG just for the duration of the save
    live_params = copy_params(self.netG)
    load_params(self.netG, self.avg_param_G)

    checkpoint = {
        "cnt": cnt,
        "netG": get_state_dict(self.netG),
        "optimG": self.netG_optimizer.state_dict(),
        "scaler": self.scaler.state_dict(),
        "settings": self.opt,
    }
    # per-stage discriminators, keyed by image resolution (64, 128, ...)
    for stage in range(self.stage_num):
        resolution = str(64 * 2 ** stage)
        checkpoint["netD_" + resolution] = get_state_dict(self.netsD[stage])
        checkpoint["optimD_" + resolution] = self.netD_optimizers[stage].state_dict()

    torch.save(checkpoint, f"{self.save_model_dir}/{model_name}")
    # put the live weights back so training continues unaffected
    load_params(self.netG, live_params)
def _get_cp_avg_bleu(self, use_beam=False, return_sentences=False):
    """Generate translations using checkpoint-averaged weights, then restore
    the model's original parameters.

    NOTE(review): ``return_sentences`` is accepted but never used here —
    kept for signature compatibility; confirm against callers.
    """
    original_params = copy_params(self.model)
    load_params(self.model, self._cp_avg())
    pred_words = self.generator.generate_loader(self.data_loader, use_beam)
    state_dict = get_state_dict(self.model)
    load_params(self.model, original_params)
    return pred_words, state_dict
def train(self, start_cnt):
    """Joint training loop: alternate one MNMT epoch (run with the EMA
    generator weights swapped in) with ``opt.T2I_per_MNMT`` T2I (GAN) epochs
    run on the live generator weights.

    :param start_cnt: epoch number to resume from (inclusive).
    """
    self.device = get_device(self.MNMT)
    start_all = time()
    for epoch_cnt in range(start_cnt, self.opt.max_epoch + 1):
        self.logger.info(f"\n[ Epoch {epoch_cnt} ]")
        # --- train MNMT ---
        start_span = time()
        # save the live generator weights, then swap in the averaged (EMA)
        # weights for the MNMT phase; the backup is restored before T2I below
        backup_para = copy_params(self.netG)
        load_params(self.netG, self.avg_param_G)
        logs = self.MNMT_train_epoch()
        time_span = (time() - start_span) / 60
        self.logger.info(f"{logs}, time : {time_span:.2f} min")
        # --- valid MNMT ---
        # validation returns the state_dict to checkpoint; otherwise snapshot
        # the current MNMT weights
        if self.validator is not None:
            state_dict = self.validation(epoch_cnt)
        else:
            state_dict = get_state_dict(self.MNMT)
        # early stopping: presumably stop_cnt counts epochs without
        # improvement (managed elsewhere) — TODO confirm. NOTE(review): on
        # this break netG still holds the EMA weights (backup not restored).
        if self.stop_cnt == self.opt.early_stop:
            break
        self.save_models(epoch_cnt, state_dict, f"epoch_{epoch_cnt}.pth")
        # --- train T2I ---
        start_span = time()
        # restore the live (non-EMA) generator weights for GAN training
        load_params(self.netG, backup_para)
        for _ in range(self.opt.T2I_per_MNMT):
            D_logs, G_logs = self.T2I_train_epoch()
        time_span = (time() - start_span) / 60
        # only the last T2I epoch's logs are reported
        self.logger.info(f"{D_logs}\n{G_logs}\ntime : {time_span:.2f} min")
    time_all = (time() - start_all) / 3600
    self.logger.info(
        f"\nbest_epoch : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {time_all:.2f} h"
    )
def _get_no_cp_avg_bleu(self, use_beam=False):
    """Generate translations with the model's current (non-averaged) weights
    and return them together with a snapshot of the model state.
    """
    hypotheses = self.generator.generate_loader(self.data_loader, use_beam)
    return hypotheses, get_state_dict(self.model)
def train_by_step(self, start_step=1):
    """Step-based training loop with gradient accumulation and AMP.

    One optimizer step consumes ``opt.grad_accumulation`` micro-batches.
    Every ``checkpoint_interval`` optimizer steps the loop logs, validates
    (when a validator exists), and checkpoints steps past the first third
    of training; a final partial checkpoint is taken after the loop.

    Fix: ``data_iter.next()`` is the Python 2 iterator API (also removed
    from modern PyTorch DataLoader iterators) and raises AttributeError on
    Python 3 — replaced with the builtin ``next(data_iter)``.

    :param start_step: optimizer-step number to resume from (inclusive).
    """
    self.device = get_device(self.model)
    self.model.train()
    self.optimizer.zero_grad()
    data_iter = iter(self.train_loader)

    checkpoint_interval = 1500
    checkpoint_cnt = 1      # micro-batches since the last checkpoint
    checkpoint_loss = 0.    # summed per-micro-batch loss since last checkpoint
    step_cnt = start_step
    # total micro-batch iterations needed to reach opt.max_step
    iter_num = (self.opt.max_step - start_step + 1) * self.opt.grad_accumulation

    start_all = time()
    start_span = time()
    pbar = tqdm(range(iter_num), ncols=90, mininterval=0.5, ascii=True)
    for _ in pbar:
        try:
            train_datas = next(data_iter)
        except StopIteration:
            # loader exhausted — restart it and keep going
            data_iter = iter(self.train_loader)
            train_datas = next(data_iter)

        with autocast(self.opt.use_amp):
            loss, batch_size = self.cal_loss(*train_datas)
        checkpoint_loss += loss.item()
        # scale the loss so accumulated gradients average over micro-batches
        loss /= self.opt.grad_accumulation
        self.scaler.scale(loss).backward()

        if checkpoint_cnt % self.opt.grad_accumulation == 0:
            # enough gradients accumulated -> one real optimizer step
            self.scaler.unscale_(self.optimizer)
            clip_grad_norm_(self.model.parameters(), self.opt.max_norm)
            self.scheduler.update_lr()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            avg_checkpoint_loss = checkpoint_loss / checkpoint_cnt
            pbar.set_description(
                f"\r[{step_cnt}/{self.opt.max_step}] " \
                f"word_loss : {avg_checkpoint_loss:.2f}, batch_size : {batch_size:<5}"
            )

            if step_cnt % checkpoint_interval == 0:
                print()  # break the tqdm line before logging
                time_span = (time() - start_span) / 60
                self.logger.info(f"\n[ Step {step_cnt} ]")
                self.logger.info(
                    f"word_loss : {avg_checkpoint_loss:.2f}, time : {time_span:.2f} min"
                )
                if self.validator is not None:
                    state_dict = self.validation(step_cnt)
                else:
                    state_dict = get_state_dict(self.model)
                # skip checkpoints during the first third of training
                if step_cnt > self.opt.max_step / 3:
                    self.save_model(step_cnt, state_dict, f"step_{step_cnt}.pth")
                start_span = time()
                checkpoint_cnt = 0
                checkpoint_loss = 0.
                print()
            step_cnt += 1
        checkpoint_cnt += 1

    # step_cnt was advanced once past the last completed optimizer step
    step_cnt = step_cnt - 1
    if step_cnt % checkpoint_interval != 0:
        # final partial checkpoint (last step did not land on the interval)
        self.logger.info(f"\n[ Step {step_cnt} ]")
        self.logger.info(f"word_loss : {avg_checkpoint_loss:.2f}")
        if self.validator is not None:
            state_dict = self.validation(step_cnt)
        else:
            state_dict = get_state_dict(self.model)
        self.save_model(step_cnt, state_dict, f"step_{step_cnt}.pth")

    time_all = (time() - start_all) / 3600
    self.logger.info(
        f"\nbest_step : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {time_all:.2f} h"
    )