Ejemplo n.º 1
0
 def save_models(self, cnt, state_dict, model_name):
     """Write one checkpoint holding all models, optimizers and run state.

     The checkpoint is a dict with the keys ``cnt``, ``models``, ``optims``,
     ``scaler``, ``steps_cnt`` and ``settings``; it is written with
     ``torch.save`` to ``{self.save_model_dir}/{model_name}``.

     Args:
         cnt: Counter value stored under ``"cnt"``.
         state_dict: Weights stored as-is under ``models["MNMT"]``.
         model_name: File name of the checkpoint inside ``save_model_dir``.
     """
     # The target-side DAMSM network and its optimizer are optional.
     if self.tgt_DAMSM_RNN is None:
         damsm_weights = None
     else:
         damsm_weights = get_state_dict(self.tgt_DAMSM_RNN)
     model_states = {
         "MNMT": state_dict,
         "tgt_DAMSM": damsm_weights,
         "netG": get_state_dict(self.netG),
     }

     mnmt_optim_state = self.MNMT_optimizer.state_dict()
     if self.DAMSM_optimizer is None:
         damsm_optim_state = None
     else:
         damsm_optim_state = self.DAMSM_optimizer.state_dict()
     optim_states = {
         "MNMT": mnmt_optim_state,
         "tgt_DAMSM": damsm_optim_state,
         "netG": self.netG_optimizer.state_dict(),
     }

     checkpoint = {
         "cnt": cnt,
         "models": model_states,
         "optims": optim_states,
         "scaler": self.scaler.state_dict(),
         "steps_cnt": self.scheduler.current_steps,
         "settings": self.opt,
     }

     # One discriminator per stage, keyed "netD_64", "netD_128", ...
     for stage in range(self.stage_num):
         disc_key = 'netD_' + str(64 * 2**stage)
         model_states[disc_key] = get_state_dict(self.netsD[stage])
         optim_states[disc_key] = self.netD_optimizers[stage].state_dict()

     target_path = f"{self.save_model_dir}/{model_name}"
     torch.save(checkpoint, target_path)
Ejemplo n.º 2
0
 def save_model(self, cnt, model_name):
     """Save both encoders, their optimizers and the scaler to one file.

     The checkpoint is written with ``torch.save`` to
     ``{self.opt.save_model_dir}/{model_name}``.

     Args:
         cnt: Counter value stored under ``"cnt"``.
         model_name: File name of the checkpoint.
     """
     checkpoint = dict(
         cnt=cnt,
         image_encoder=get_state_dict(self.image_encoder),
         text_encoder=get_state_dict(self.text_encoder),
         image_optimizer=self.image_optimizer.state_dict(),
         text_optimizer=self.text_optimizer.state_dict(),
         scaler=self.scaler.state_dict(),
         settings=self.opt,
     )
     target_path = f"{self.opt.save_model_dir}/{model_name}"
     torch.save(checkpoint, target_path)
Ejemplo n.º 3
0
    def train_by_epoch(self, start_epoch=1):
        """Train epoch by epoch from ``start_epoch`` to ``self.opt.max_epoch``.

        After each epoch the average loss is logged and the weights to save
        are taken from ``self.validation`` when a validator is configured,
        otherwise from the current model. Checkpoints are written only once
        the epoch number exceeds one third of ``max_epoch``.

        Args:
            start_epoch: First epoch number to run (inclusive).
        """
        self.device = get_device(self.model)
        self.model.train()

        run_started = time()
        for epoch in range(start_epoch, self.opt.max_epoch + 1):
            self.logger.info(f"\n[ Epoch {epoch} ]")

            epoch_started = time()
            epoch_loss = self._train_epoch()
            minutes = (time() - epoch_started) / 60
            summary = f"word_loss : {epoch_loss:.2f}, time : {minutes:.2f} min"
            self.logger.info(summary)

            weights = (self.validation(epoch) if self.validator is not None
                       else get_state_dict(self.model))

            # Skip checkpoints during the first third of training.
            if not epoch > self.opt.max_epoch / 3:
                continue
            self.save_model(epoch, weights, f"epoch_{epoch}.pth")

        hours = (time() - run_started) / 3600
        self.logger.info(
            f"\nbest_epoch : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {hours:.2f} h"
        )
Ejemplo n.º 4
0
 def save_model(self, cnt, model_name):
     """Save the generator (with averaged parameters) and all discriminators.

     ``self.avg_param_G`` is loaded into ``self.netG`` for the duration of
     the save, so the ``"netG"`` entry of the checkpoint is taken while
     those parameters are in place. The parameters that were in the
     generator beforehand are restored afterwards, even if building or
     writing the checkpoint raises.

     Args:
         cnt: Counter value stored under ``"cnt"``.
         model_name: File name of the checkpoint inside
             ``self.save_model_dir``.
     """
     backup_para = copy_params(self.netG)
     try:
         load_params(self.netG, self.avg_param_G)
         save_dict = {
             "cnt": cnt,
             "netG": get_state_dict(self.netG),
             "optimG": self.netG_optimizer.state_dict(),
             "scaler": self.scaler.state_dict(),
             "settings": self.opt,
         }
         # One discriminator per stage, keyed by 64, 128, 256, ...
         for i in range(self.stage_num):
             netD_name = "netD_" + str(64 * 2**i)
             optimD_name = "optimD_" + str(64 * 2**i)
             save_dict[netD_name] = get_state_dict(self.netsD[i])
             save_dict[optimD_name] = self.netD_optimizers[i].state_dict()
         torch.save(save_dict, f"{self.save_model_dir}/{model_name}")
     finally:
         # Always put the pre-save weights back; otherwise a failed save
         # (e.g. unwritable path) would leave netG on the averaged weights.
         load_params(self.netG, backup_para)
 def _get_cp_avg_bleu(self, use_beam=False, return_sentences=False):
     """Generate predictions with the parameters returned by ``_cp_avg``.

     The model's current parameters are backed up, the parameters from
     ``self._cp_avg()`` are loaded, predictions are generated over
     ``self.data_loader`` and the model state is captured. The backed-up
     parameters are restored afterwards, even if generation raises.

     Args:
         use_beam: Forwarded to ``self.generator.generate_loader``.
         return_sentences: Unused; kept for interface compatibility.

     Returns:
         Tuple ``(pred_words, state_dict)``.
     """
     back_up_params = copy_params(self.model)
     try:
         avg_params = self._cp_avg()
         load_params(self.model, avg_params)
         pred_words = self.generator.generate_loader(self.data_loader, use_beam)
         # NOTE(review): if get_state_dict returns tensors that alias the
         # live parameters, the restore below would also alter this
         # state_dict - confirm that it returns copies.
         state_dict = get_state_dict(self.model)
     finally:
         # Restore on every exit path so a failure during generation does
         # not leave the model on the swapped-in parameters.
         load_params(self.model, back_up_params)
     return pred_words, state_dict
Ejemplo n.º 6
0
    def train(self, start_cnt):
        """Alternate MNMT and T2I training for each epoch.

        Per epoch: load ``self.avg_param_G`` into ``self.netG``, run one
        MNMT training epoch, validate (when a validator is configured),
        save a checkpoint, then restore the generator's previous
        parameters and run ``self.opt.T2I_per_MNMT`` T2I training epochs.
        The loop ends early when ``self.stop_cnt`` equals
        ``self.opt.early_stop``.

        Args:
            start_cnt: First epoch number to run (inclusive); the last one
                is ``self.opt.max_epoch``.
        """
        self.device = get_device(self.MNMT)

        start_all = time()
        for epoch_cnt in range(start_cnt, self.opt.max_epoch + 1):
            self.logger.info(f"\n[ Epoch {epoch_cnt} ]")

            # --- train MNMT ---
            start_span = time()
            backup_para = copy_params(self.netG)

            # MNMT training and the checkpoint below both see the
            # generator with avg_param_G loaded.
            load_params(self.netG, self.avg_param_G)
            logs = self.MNMT_train_epoch()
            time_span = (time() - start_span) / 60
            self.logger.info(f"{logs}, time : {time_span:.2f} min")

            # --- valid MNMT ---
            if self.validator is not None:
                state_dict = self.validation(epoch_cnt)
            else:
                state_dict = get_state_dict(self.MNMT)
            if self.stop_cnt == self.opt.early_stop:
                # NOTE(review): on early stop netG keeps avg_param_G
                # (backup_para is not restored) - confirm this is intended.
                break
            self.save_models(epoch_cnt, state_dict, f"epoch_{epoch_cnt}.pth")

            # --- train T2I ---
            start_span = time()
            load_params(self.netG, backup_para)
            t2i_ran = False
            for _ in range(self.opt.T2I_per_MNMT):
                D_logs, G_logs = self.T2I_train_epoch()
                t2i_ran = True
            time_span = (time() - start_span) / 60
            # D_logs/G_logs only exist if the loop ran at least once; with
            # T2I_per_MNMT == 0 there is nothing to report, and logging
            # would raise NameError.
            if t2i_ran:
                self.logger.info(
                    f"{D_logs}\n{G_logs}\ntime : {time_span:.2f} min")

        time_all = (time() - start_all) / 3600
        self.logger.info(
            f"\nbest_epoch : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {time_all:.2f} h"
        )
 def _get_no_cp_avg_bleu(self, use_beam=False):
     """Generate predictions with the model's current parameters.

     Args:
         use_beam: Forwarded to ``self.generator.generate_loader``.

     Returns:
         Tuple ``(pred_words, state_dict)`` where ``state_dict`` is taken
         from ``self.model`` after generation.
     """
     predictions = self.generator.generate_loader(
         self.data_loader,
         use_beam,
     )
     return predictions, get_state_dict(self.model)
Ejemplo n.º 8
0
    def train_by_step(self, start_step=1):
        """Train by optimizer steps from ``start_step`` to ``self.opt.max_step``.

        Runs ``(max_step - start_step + 1) * grad_accumulation``
        micro-batches, restarting the data loader whenever it is exhausted.
        An optimizer step is taken every ``grad_accumulation`` micro-batches.
        Every ``checkpoint_interval`` steps the running loss is logged, the
        model is validated (when a validator is configured) and, once past
        one third of ``max_step``, a checkpoint is saved. A final
        validation and checkpoint is done after the loop if the last step
        did not fall on a checkpoint boundary.

        Args:
            start_step: Number of the first optimizer step (inclusive).
        """
        self.device = get_device(self.model)

        self.model.train()
        self.optimizer.zero_grad()
        data_iter = iter(self.train_loader)
        checkpoint_interval = 1500
        # Counts micro-batches since the last checkpoint; starts at 1 so the
        # modulo test below fires on every grad_accumulation-th batch.
        checkpoint_cnt = 1
        checkpoint_loss = 0.
        # Bound up front so the trailing log cannot hit an unbound name when
        # no optimizer step runs; NaN marks "no loss measured".
        avg_checkpoint_loss = float("nan")

        step_cnt = start_step
        iter_num = (self.opt.max_step - start_step +
                    1) * self.opt.grad_accumulation
        start_all = time()
        start_span = time()
        pbar = tqdm(range(iter_num), ncols=90, mininterval=0.5, ascii=True)
        for _ in pbar:
            # Use the built-in next(): Python 3 iterators expose __next__,
            # not a .next() method.
            try:
                train_datas = next(data_iter)
            except StopIteration:
                # Loader exhausted: start another pass over the data.
                data_iter = iter(self.train_loader)
                train_datas = next(data_iter)

            with autocast(self.opt.use_amp):
                loss, batch_size = self.cal_loss(*train_datas)
            checkpoint_loss += loss.item()
            # Scale so the accumulated gradient is averaged over the
            # grad_accumulation micro-batches.
            loss /= self.opt.grad_accumulation

            self.scaler.scale(loss).backward()

            if checkpoint_cnt % self.opt.grad_accumulation == 0:
                # Unscale before clipping so max_norm applies to the real
                # gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                clip_grad_norm_(self.model.parameters(), self.opt.max_norm)
                self.scheduler.update_lr()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                avg_checkpoint_loss = checkpoint_loss / checkpoint_cnt
                pbar.set_description(
                    f"\r[{step_cnt}/{self.opt.max_step}] " \
                    f"word_loss : {avg_checkpoint_loss:.2f}, batch_size : {batch_size:<5}"
                )

                if step_cnt % checkpoint_interval == 0:
                    print()
                    time_span = (time() - start_span) / 60
                    self.logger.info(f"\n[ Step {step_cnt} ]")
                    self.logger.info(
                        f"word_loss : {avg_checkpoint_loss:.2f}, time : {time_span:.2f} min"
                    )

                    if self.validator is not None:
                        state_dict = self.validation(step_cnt)
                    else:
                        state_dict = get_state_dict(self.model)

                    # Skip checkpoints during the first third of training.
                    if step_cnt > self.opt.max_step / 3:
                        self.save_model(step_cnt, state_dict,
                                        f"step_{step_cnt}.pth")

                    start_span = time()
                    # Reset to 0; the increment at the end of the iteration
                    # brings it back to 1 for the next micro-batch.
                    checkpoint_cnt = 0
                    checkpoint_loss = 0.
                    print()

                step_cnt += 1
            checkpoint_cnt += 1

        # step_cnt was advanced past the last completed step.
        step_cnt = step_cnt - 1
        if step_cnt % checkpoint_interval != 0:
            # The last step was not a checkpoint boundary: validate and save
            # once more so the final weights are kept.
            self.logger.info(f"\n[ Step {step_cnt} ]")
            self.logger.info(f"word_loss : {avg_checkpoint_loss:.2f}")
            if self.validator is not None:
                state_dict = self.validation(step_cnt)
            else:
                state_dict = get_state_dict(self.model)
            self.save_model(step_cnt, state_dict, f"step_{step_cnt}.pth")

        time_all = (time() - start_all) / 3600
        self.logger.info(
            f"\nbest_step : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {time_all:.2f} h"
        )