Esempio n. 1
0
    def train_process(self):
        """Run the complete train/validate loop for CycleSR.

        Sets up logging, CUDA (when ``cfg.cuda``), the model, Horovod (when
        ``self.horovod``) and the train/test dataloaders. Then, for every epoch
        in ``[cfg.epoch_count, cfg.n_epoch + cfg.n_epoch_decay]``:

        * updates the learning rates and trains one epoch;
        * every ``cfg.eval_epoch`` epochs, runs ``_evalGAN`` and ``_valid``,
          saves an ``epoch<N>`` checkpoint and calls ``_backup()``.

        The ``model_best`` checkpoint is written only when ``_valid`` returns a
        PSNR strictly higher than the best one seen so far. The SummaryWriter
        is closed on exit, including when training raises.
        """
        init_log(log_file="worker_{}.txt".format(self.worker_id))
        if self.cfg.cuda:
            self._init_cuda_setting()
        self.model = self._init_model()
        if self.horovod:
            self._horovod_init_optimizer()
            self._init_horovod_setting()
        self.train_data = self._init_dataloader('train')
        self.valid_data = self._init_dataloader('test')
        train_dataloader = self.train_data.dataloader
        valid_dataloader = self.valid_data.dataloader

        writer = SummaryWriter(self.get_local_worker_path())
        try:
            start_time = time.time()
            # Kept unrounded; rounded once at the end to avoid per-epoch rounding drift.
            train_time = 0.0
            best_psnr = -np.inf
            best_epoch = 0
            logging.info("==> Start training")
            # Sampled once (val_num=5) from the training data and reused at every _evalGAN call.
            val_gan_imgs = self.getValImg(self.train_data, val_num=5)
            last_epoch = self.cfg.n_epoch + self.cfg.n_epoch_decay
            for epoch in range(self.cfg.epoch_count, last_epoch + 1):
                self.model.update_learning_rate(
                    epoch,
                    self.cfg.model_desc.custom.cyc_lr,
                    self.cfg.model_desc.custom.SR_lr,
                    self.cfg.n_epoch,
                    self.cfg.n_epoch_decay)
                start_train_time = time.time()
                self._train(train_dataloader, writer, epoch, self.model,
                            print_freq=self.cfg.print_freq)
                train_time += time.time() - start_train_time

                # Validation and checkpointing only happen every eval_epoch epochs.
                if epoch % self.cfg.eval_epoch != 0:
                    continue
                logging.info("==> Validating")
                self._evalGAN(self.model, val_gan_imgs, epoch, writer)
                val_ave_psnr = self._valid(self.model, valid_dataloader, epoch, self.cfg.eval_epoch, writer,
                                           self.cfg.val_ps_offset)
                if val_ave_psnr is not None:
                    logging.info("==> Current ave psnr is {:.3f}".format(val_ave_psnr))
                    if val_ave_psnr > best_psnr:
                        best_psnr = val_ave_psnr
                        best_epoch = epoch
                        # Overwrite the best checkpoint only on a real improvement;
                        # otherwise a worse model would replace the best one.
                        self._save_checkpoint('model_best')
                    logging.info(
                        "==> Best PSNR on val dataset {:.3f}, achieved at epoch {}".format(best_psnr, best_epoch))
                model_name = 'epoch' + str(epoch)
                logging.info("Saving checkpoints to {}".format(model_name))
                self._save_checkpoint(model_name)
                self._backup()

            elapsed = str(datetime.timedelta(seconds=round(time.time() - start_time)))
            train_time = str(datetime.timedelta(seconds=round(train_time)))
            logging.info(
                "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
            self._backup()
            logging.info("***** Move Done! *****")
        finally:
            # Flush pending events and release the writer's file handles.
            writer.close()
Esempio n. 2
0
 def train_process(self):
     """Run the evaluation worker from start to finish.

     Opens the per-worker ``gpu_eva_<id>.txt`` log and calls
     ``_init_all_settings()``. It then evaluates ``self.valid_loader`` with
     ``self.valid``, persists the result via ``_save_performance``, and finally
     stores it on ``self.evaluate_result``. Returns None.
     """
     log_file = "gpu_eva_{}.txt".format(self.worker_id)
     init_log(log_file=log_file)
     logging.info("start evaluate process")
     self._init_all_settings()
     # valid_loader is read only after _init_all_settings(); presumably that
     # call populates it — confirm against the class.
     result = self.valid(self.valid_loader)
     self._save_performance(result)
     done_msg = "finished evaluate for id {}".format(self.worker_id)
     logging.info(done_msg)
     self.evaluate_result = result
Esempio n. 3
0
    def train_process(self):
        """Train this TrainWorker through the unified Trainer.

        Opens the per-worker ``worker_<id>.txt`` log. Unless ``self.lazy_built``
        is truthy, it calls ``build()`` with the worker's model, hyper-parameters
        and checkpoint-loading flag. It then runs ``self.train()``. Where the
        trained model and validation results are written is decided inside
        ``train()`` and is not visible here (presumably local_worker_path and
        s3_path — confirm).
        """
        log_file = "worker_{}.txt".format(self.worker_id)
        init_log(log_file=log_file)
        logging.debug("Use the unified Trainer")
        must_build = not self.lazy_built
        if must_build:
            self.build(
                model=self.model,
                hps=self.hps,
                load_ckpt_flag=self.load_ckpt_flag,
            )
        self.train()
Esempio n. 4
0
 def train_process(self):
     """Evaluate a model on the test split and broadcast its performance.

     The model is obtained from the first applicable source:

     1. ``ModelZoo.get_model(model_desc, weights_file)`` when both are truthy;
     2. ``load_checkpoint(saved_folder, saved_step_name)`` when
        ``_flag_load_checkpoint`` is truthy;
     3. ``_load_pretrained_model()`` otherwise.

     A fresh test-mode dataloader is then stored on ``self.valid_loader``,
     evaluated with ``self.valid``, and the result is passed to ``_broadcast``.
     """
     init_log(log_file="gpu_eva_{}.txt".format(self.worker_id))
     logging.info("start evaluate process")
     if self.model_desc and self.weights_file:
         self.model = ModelZoo.get_model(self.model_desc, self.weights_file)
     elif self._flag_load_checkpoint:
         self.load_checkpoint(
             saved_folder=self.saved_folder,
             step_name=self.saved_step_name,
         )
     else:
         self._load_pretrained_model()
     self.valid_loader = self._init_dataloader(mode='test')
     result = self.valid(self.valid_loader)
     self._broadcast(result)
     done_msg = "finished evaluate for id {}".format(self.worker_id)
     logging.info(done_msg)