Exemplo n.º 1
0
    def train(self):
        """Run training and validation passes until ``self.num_epoch`` is reached.

        Each iteration increments ``self.epoch``, runs ``self.iter_one_epoch()``
        for training, calls ``self.save`` for a per-epoch path and for
        ``last.pt`` under ``self.exp_dir``, then runs
        ``self.iter_one_epoch(cross_valid=True)``. Both losses are logged and
        appended to ``self.tr_loss`` / ``self.cv_loss``, and
        ``self.best_cvloss`` tracks the smallest validation loss seen.
        """
        stopwatch = utils.Timer()
        # Seed the best validation loss from any recorded history; otherwise
        # start from a very large sentinel so the first epoch always wins.
        self.best_cvloss = min(self.cv_loss) if self.cv_loss else 9e20
        rule = "-"*85

        while self.epoch < self.num_epoch:
            stopwatch.tic()
            self.epoch += 1

            # ---- training pass ----
            logging.info("Training")
            tr_loss = self.iter_one_epoch()
            tr_summary = "tr loss: {:.4f}".format(tr_loss)
            logging.info("".join([
                "\n", rule, "\n",
                "Epoch {} Training Summary:\n{}\n".format(self.epoch, tr_summary),
                rule,
            ]))

            # Save before validating: first the per-epoch path, then "last.pt".
            for ckpt_name in ("ep-{:04d}.pt".format(self.epoch), "last.pt"):
                self.save(os.path.join(self.exp_dir, ckpt_name))

            # ---- validation pass ----
            logging.info("Validation")
            cv_loss = self.iter_one_epoch(cross_valid=True)
            if cv_loss < self.best_cvloss:
                self.best_cvloss = cv_loss

            # The timer was started before training, so this covers both passes.
            elapsed = stopwatch.toc()
            cv_summary = "cv loss: {:.4f} | best cv loss {:.4f}".format(
                cv_loss, self.best_cvloss)
            logging.info("".join([
                "\n", rule, "\n",
                "Epoch {} Validation Summary:\n{}\n".format(self.epoch, cv_summary),
                "Time cost: {:.4f} min".format(elapsed/60.),
                "\n", rule, "\n",
            ]))

            self.tr_loss.append(tr_loss)
            self.cv_loss.append(cv_loss)

            # A falsy num_last_ckpt_keep skips the cleanup call entirely.
            if self.num_last_ckpt_keep:
                utils.cleanup_ckpt(self.exp_dir, self.num_last_ckpt_keep)
Exemplo n.º 2
0
    def train(self):
        """Run training and validation passes until ``self.num_epoch`` is reached.

        Before the loop, ``self.visualize_figure()`` is called when
        ``utils.TENSORBOARD_LOGGING == 1`` and ``self.config["vis_atten"]`` is
        truthy. Each iteration increments ``self.epoch``, runs a training pass,
        calls ``self.save`` for a per-epoch path and for ``last-ckpt.pt``, then
        runs a validation pass. Loss and ``exp(loss)`` (logged as "ppl") are
        reported for both passes, the losses are appended to
        ``self.tr_loss`` / ``self.cv_loss``, and ``self.best_cvloss`` tracks
        the smallest validation loss seen.
        """
        stopwatch = utils.Timer()
        # Seed the best validation loss from any recorded history; otherwise
        # start from a very large sentinel so the first epoch always wins.
        self.best_cvloss = min(self.cv_loss) if self.cv_loss else 9e20

        if utils.TENSORBOARD_LOGGING == 1 and self.config["vis_atten"]:
            self.visualize_figure()

        rule = "-" * 85

        while self.epoch < self.num_epoch:
            stopwatch.tic()
            self.epoch += 1

            # ---- training pass ----
            logging.info("Training")
            tr_loss = self.iter_one_epoch()
            tr_ppl = np.exp(tr_loss)
            tr_summary = "".join([
                "tr loss: {:.4f}".format(tr_loss),
                ", tr ppl {:.4f}".format(tr_ppl),
            ])
            logging.info("".join([
                "\n", rule, "\n",
                "Epoch {} Training Summary:\n{}\n".format(self.epoch, tr_summary),
                rule,
            ]))

            # Save before validating: per-epoch path first, then "last-ckpt.pt".
            for ckpt_name in ("ep-{:04d}.pt".format(self.epoch), "last-ckpt.pt"):
                self.save(os.path.join(self.exp_dir, ckpt_name))

            # ---- validation pass ----
            logging.info("Validation")
            cv_loss = self.iter_one_epoch(cross_valid=True)
            if cv_loss < self.best_cvloss:
                self.best_cvloss = cv_loss

            # The timer was started before training, so this covers both passes.
            elapsed = stopwatch.toc()
            cv_ppl = np.exp(cv_loss)
            cv_summary = "".join([
                "cv loss: {:.4f} | best cv loss {:.4f} | ".format(
                    cv_loss, self.best_cvloss),
                "cv ppl: {:.4f} | best cv ppl {:.4f} | ".format(
                    cv_ppl, np.exp(self.best_cvloss)),
            ])
            logging.info("".join([
                "\n", rule, "\n",
                "Epoch {} Validation Summary:\n{}\n".format(self.epoch, cv_summary),
                "Time cost: {:.4f} min".format(elapsed / 60.),
                "\n", rule, "\n",
            ]))

            # Only this scheduler type is given the validation perplexity.
            if isinstance(self.lr_scheduler, schedule.BobLearningRateSchedule):
                self.lr_scheduler.update_decay_rate(cv_ppl)

            self.tr_loss.append(tr_loss)
            self.cv_loss.append(cv_loss)

            if utils.TENSORBOARD_LOGGING == 1:
                scalars = (
                    ("tr_loss/loss", tr_loss),
                    ("cv_loss/loss", cv_loss),
                    ("tr_loss/ppl", tr_ppl),
                    ("cv_loss/ppl", cv_ppl),
                )
                for tag, value in scalars:
                    utils.visualizer.add_scalar(tag, value, self.epoch)

            # A falsy num_last_ckpt_keep skips the cleanup call entirely.
            if self.num_last_ckpt_keep:
                utils.cleanup_ckpt(self.exp_dir, self.num_last_ckpt_keep)
Exemplo n.º 3
0
def test_cleanup():
    """Check that ``utils.cleanup_ckpt`` prunes a directory of checkpoint files.

    Creates 120 empty ``ckpt-NNNN.pt`` files plus ``last-ckpt.pt`` in
    ``testdata/cleanup``, calls ``utils.cleanup_ckpt(expdir, 3)``, and asserts
    that exactly four directory entries remain.
    """
    expdir = "testdata/cleanup"
    # Tolerate a directory left behind by a previous run instead of raising.
    if not os.path.isdir(expdir):
        os.makedirs(expdir)
    for i in range(120):
        with open(os.path.join(expdir, "ckpt-{:04d}.pt".format(i)), 'w') as f:
            f.write("")
    with open(os.path.join(expdir, "last-ckpt.pt"), 'w') as f:
        f.write("")
    utils.cleanup_ckpt(expdir, 3)
    # The original comparison was a bare expression whose result was
    # discarded, so the test could never fail; it must be asserted.
    # NOTE(review): 4 presumably means 3 kept checkpoints + last-ckpt.pt --
    # confirm against cleanup_ckpt's contract.
    assert len(os.listdir(expdir)) == 4