def train(self):
    """Run the epoch loop until ``self.epoch`` reaches ``self.num_epoch``.

    ``self.best_cvloss`` starts from the smallest entry of ``self.cv_loss``
    when that history is non-empty, otherwise from the sentinel ``9e20``.

    Each epoch performs, in order:

    1. a training pass via ``self.iter_one_epoch()`` and a logged summary;
    2. two checkpoint saves into ``self.exp_dir`` (an ``ep-NNNN.pt`` file
       numbered by epoch, and ``last.pt``);
    3. a validation pass via ``self.iter_one_epoch(cross_valid=True)``,
       lowering ``self.best_cvloss`` when the new loss is smaller;
    4. a logged validation summary, then both losses are appended to
       ``self.tr_loss`` / ``self.cv_loss``;
    5. ``utils.cleanup_ckpt`` when ``self.num_last_ckpt_keep`` is truthy.
    """
    rule = "-" * 85
    stopwatch = utils.Timer()

    self.best_cvloss = min(self.cv_loss) if self.cv_loss else 9e20

    while self.epoch < self.num_epoch:
        stopwatch.tic()
        self.epoch += 1

        # ---- training pass ----
        logging.info("Training")
        tr_loss = self.iter_one_epoch()
        tr_msg = "tr loss: {:.4f}".format(tr_loss)
        logging.info("".join([
            "\n", rule, "\n",
            "Epoch {} Training Summary:\n{}\n".format(self.epoch, tr_msg),
            rule,
        ]))

        # Checkpoints are written before validation runs.
        epoch_ckpt = "ep-{:04d}.pt".format(self.epoch)
        self.save(os.path.join(self.exp_dir, epoch_ckpt))
        self.save(os.path.join(self.exp_dir, "last.pt"))

        # ---- validation pass ----
        logging.info("Validation")
        cv_loss = self.iter_one_epoch(cross_valid=True)
        if cv_loss < self.best_cvloss:
            self.best_cvloss = cv_loss

        # NOTE(review): toc() is taken after validation, so the reported
        # time presumably spans training, saving and validation - confirm
        # against utils.Timer.
        train_time = stopwatch.toc()
        cv_msg = "cv loss: {:.4f} | best cv loss {:.4f}".format(
            cv_loss, self.best_cvloss)
        logging.info("".join([
            "\n", rule, "\n",
            "Epoch {} Validation Summary:\n{}\n".format(self.epoch, cv_msg),
            "Time cost: {:.4f} min".format(train_time / 60.),
            "\n", rule, "\n",
        ]))

        self.tr_loss.append(tr_loss)
        self.cv_loss.append(cv_loss)

        if self.num_last_ckpt_keep:
            utils.cleanup_ckpt(self.exp_dir, self.num_last_ckpt_keep)
def train(self):
    """Run the epoch loop until ``self.epoch`` reaches ``self.num_epoch``.

    ``self.best_cvloss`` starts from the smallest entry of ``self.cv_loss``
    when that history is non-empty, otherwise from the sentinel ``9e20``.
    Before the loop, ``self.visualize_figure()`` is called once when
    ``utils.TENSORBOARD_LOGGING == 1`` and ``self.config["vis_atten"]`` is
    truthy.

    Each epoch performs, in order:

    1. a training pass via ``self.iter_one_epoch()``; loss and
       ``np.exp(loss)`` (reported as "ppl") are logged;
    2. two checkpoint saves into ``self.exp_dir`` (an ``ep-NNNN.pt`` file
       numbered by epoch, and ``last-ckpt.pt``);
    3. a validation pass via ``self.iter_one_epoch(cross_valid=True)``,
       lowering ``self.best_cvloss`` when the new loss is smaller;
    4. a logged validation summary;
    5. ``update_decay_rate(np.exp(cv_loss))`` on ``self.lr_scheduler`` when
       it is a ``schedule.BobLearningRateSchedule``;
    6. both losses appended to ``self.tr_loss`` / ``self.cv_loss`` and, when
       ``utils.TENSORBOARD_LOGGING == 1``, four scalars written through
       ``utils.visualizer.add_scalar``;
    7. ``utils.cleanup_ckpt`` when ``self.num_last_ckpt_keep`` is truthy.
    """
    rule = "-" * 85
    stopwatch = utils.Timer()

    self.best_cvloss = min(self.cv_loss) if self.cv_loss else 9e20

    if utils.TENSORBOARD_LOGGING == 1 and self.config["vis_atten"]:
        self.visualize_figure()

    while self.epoch < self.num_epoch:
        stopwatch.tic()
        self.epoch += 1

        # ---- training pass ----
        logging.info("Training")
        tr_loss = self.iter_one_epoch()
        tr_ppl = np.exp(tr_loss)
        tr_msg = "".join([
            "tr loss: {:.4f}".format(tr_loss),
            ", tr ppl {:.4f}".format(tr_ppl),
        ])
        logging.info("".join([
            "\n", rule, "\n",
            "Epoch {} Training Summary:\n{}\n".format(self.epoch, tr_msg),
            rule,
        ]))

        # Checkpoints are written before validation runs.
        epoch_ckpt = "ep-{:04d}.pt".format(self.epoch)
        self.save(os.path.join(self.exp_dir, epoch_ckpt))
        self.save(os.path.join(self.exp_dir, "last-ckpt.pt"))

        # ---- validation pass ----
        logging.info("Validation")
        cv_loss = self.iter_one_epoch(cross_valid=True)
        if cv_loss < self.best_cvloss:
            self.best_cvloss = cv_loss

        # NOTE(review): toc() is taken after validation, so the reported
        # time presumably spans training, saving and validation - confirm
        # against utils.Timer.
        train_time = stopwatch.toc()
        cv_ppl = np.exp(cv_loss)
        cv_msg = "".join([
            "cv loss: {:.4f} | best cv loss {:.4f} | ".format(
                cv_loss, self.best_cvloss),
            "cv ppl: {:.4f} | best cv ppl {:.4f} | ".format(
                cv_ppl, np.exp(self.best_cvloss)),
        ])
        logging.info("".join([
            "\n", rule, "\n",
            "Epoch {} Validation Summary:\n{}\n".format(self.epoch, cv_msg),
            "Time cost: {:.4f} min".format(train_time / 60.),
            "\n", rule, "\n",
        ]))

        # This scheduler type is driven by the validation "ppl" value.
        if isinstance(self.lr_scheduler, schedule.BobLearningRateSchedule):
            self.lr_scheduler.update_decay_rate(cv_ppl)

        self.tr_loss.append(tr_loss)
        self.cv_loss.append(cv_loss)

        if utils.TENSORBOARD_LOGGING == 1:
            scalars = (
                ("tr_loss/loss", tr_loss),
                ("cv_loss/loss", cv_loss),
                ("tr_loss/ppl", tr_ppl),
                ("cv_loss/ppl", cv_ppl),
            )
            for tag, value in scalars:
                utils.visualizer.add_scalar(tag, value, self.epoch)

        if self.num_last_ckpt_keep:
            utils.cleanup_ckpt(self.exp_dir, self.num_last_ckpt_keep)
def test_cleanup():
    """Check that ``utils.cleanup_ckpt`` prunes a checkpoint directory.

    Creates 120 empty ``ckpt-NNNN.pt`` files plus ``last-ckpt.pt`` under
    ``testdata/cleanup``, calls ``utils.cleanup_ckpt(expdir, 3)`` and asserts
    that exactly four directory entries remain.

    NOTE(review): four is presumably the three retained checkpoints plus
    ``last-ckpt.pt`` - confirm against ``utils.cleanup_ckpt``.
    """
    expdir = "testdata/cleanup"
    # The directory is never removed afterwards, so tolerate it already
    # existing; otherwise every run after the first raises FileExistsError.
    os.makedirs(expdir, exist_ok=True)

    for i in range(120):
        with open(os.path.join(expdir, "ckpt-{:04d}.pt".format(i)), 'w') as f:
            f.write("")
    with open(os.path.join(expdir, "last-ckpt.pt"), 'w') as f:
        f.write("")

    utils.cleanup_ckpt(expdir, 3)

    # The original evaluated this comparison and discarded the result, so the
    # test could never fail; it has to be asserted.
    assert len(os.listdir(expdir)) == 4