示例#1
0
    def train_process(self):
        """Run the whole train and validate process for the fully train cyclesr.

        Steps, in order:

        1. Initialise logging, reporting, optional CUDA settings, the model and
           (when ``cfg.distributed`` is set) the horovod optimizer/settings.
        2. Build the ``'train'`` and ``'test'`` dataloaders.
        3. For every epoch from ``cfg.epoch_count`` up to and including
           ``cfg.n_epoch + cfg.n_epoch_decay``: update the learning rate, train
           one epoch and, every ``cfg.eval_epoch`` epochs, evaluate, broadcast
           the PSNR and save checkpoints.
        4. Log the total elapsed time and the accumulated pure-training time.

        The summary writer is always closed when the loop ends, including when
        training aborts with an exception, so that buffered events are flushed.
        """
        init_log(level=General.logger.level,
                 log_file="log_worker_{}.txt".format(self.trainer.worker_id),
                 log_path=self.trainer.local_log_path)
        self._init_report()
        if self.cfg.cuda:
            self.trainer._init_cuda_setting()
        self.model = self._init_model()
        if self.cfg.distributed:
            self._horovod_init_optimizer()
            self._init_horovod_setting()
        self.train_data = self._init_dataloader('train')
        self.valid_data = self._init_dataloader('test')
        train_dataloader = Adapter(self.train_data).loader
        valid_dataloader = Adapter(self.valid_data).loader

        writer = SummaryWriter(self.worker_path)

        start_time = time.time()
        train_time = 0  # accumulated seconds spent inside self._train only
        best_psnr = -np.inf
        best_epoch = 0
        try:
            logging.info("==> Start training")
            # NOTE(review): the GAN preview images are sampled from the *train*
            # data, not the test data -- confirm this is intended.
            val_gan_imgs = self.getValImg(self.train_data, val_num=5)
            # The upper bound is inclusive, hence the "+ 1".
            last_epoch = self.cfg.n_epoch + self.cfg.n_epoch_decay
            for epoch in range(self.cfg.epoch_count, last_epoch + 1):
                self.model.update_learning_rate(
                    epoch,
                    self.cfg.model_desc.custom.cyc_lr,
                    self.cfg.model_desc.custom.SR_lr,
                    self.cfg.n_epoch,
                    self.cfg.n_epoch_decay)
                start_train_time = time.time()
                self._train(train_dataloader, writer, epoch, self.model, print_freq=self.cfg.print_freq)
                train_time += round(time.time() - start_train_time)
                # Validation runs only on every ``eval_epoch``-th epoch.
                if epoch % self.cfg.eval_epoch == 0:
                    logging.info("==> Validng")
                    self._evalGAN(self.model, val_gan_imgs, epoch, writer)
                    val_ave_psnr = self._valid(self.model, valid_dataloader, epoch, self.cfg.eval_epoch, writer,
                                               self.cfg.val_ps_offset)
                    if val_ave_psnr is not None:
                        logging.info("==> Current ave psnr is {:.3f}".format(val_ave_psnr))
                        if val_ave_psnr > best_psnr:
                            best_psnr = val_ave_psnr
                            best_epoch = epoch
                        logging.info(
                            "==> Best PSNR on val dataset {:.3f}, achieved at epoch {}".format(best_psnr, best_epoch))
                        # NOTE(review): ``best=True`` is passed on every evaluated
                        # epoch, not only when the PSNR improved. Kept as-is because
                        # the PSNR broadcast below is also the current epoch's --
                        # confirm which model the "best" checkpoint should hold.
                        self._save_checkpoint(epoch, best=True)
                        self._broadcast(epoch, {"psnr": val_ave_psnr})
                    model_name = 'epoch' + str(epoch)
                    logging.info("Saving checkpoints to {}".format(model_name))
                    self._save_checkpoint(epoch)
        finally:
            # Flush pending events and release the event file even on failure.
            writer.close()
        elapsed = round(time.time() - start_time)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        train_time = str(datetime.timedelta(seconds=train_time))
        logging.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
示例#2
0
    def init_trainer(self):
        """Prepare the trainer and then initialise the train op.

        Configures the per-worker log file, installs the default functions and
        conditions, creates the callbacks and notifies them via
        ``callbacks.init_trainer()``, and finally calls ``init_train_op()``.
        """
        # Per-worker logging: the file name embeds this worker's id.
        log_level = General.logger.level
        worker_log_name = "log_worker_{}.txt".format(self.worker_id)
        init_log(level=log_level, log_file=worker_log_name, log_path=self.local_log_path)

        # Trainer defaults and callbacks; the callbacks are notified last.
        self._set_default_funcs()
        self._set_condition()
        self._init_callbacks()
        self.callbacks.init_trainer()

        self.init_train_op()
示例#3
0
 def train_process(self):
     """Evaluate the model on the host and broadcast the resulting performance.

     Sets up the evaluator log file, loads the model, builds the ``'test'``
     dataloader (stored on ``self.valid_loader``), runs ``self.valid`` on it
     and passes the returned performance to ``self._broadcast``.
     """
     log_level = General.logger.level
     evaluator_log_name = "host_evaluator_{}.log".format(self.worker_id)
     init_log(level=log_level, log_file=evaluator_log_name, log_path=self.local_log_path)
     logging.info("start evaluate process")

     self.load_model()
     self.valid_loader = self._init_dataloader(mode='test')
     # Validate and publish the result in one step.
     self._broadcast(self.valid(self.valid_loader))

     finished_msg = "the model (id {}) is evaluated on the host".format(self.worker_id)
     logging.info(finished_msg)
示例#4
0
    def train_process(self):
        """Run the whole train process of the TrainWorker specified in config.

        Configures the per-worker log file, installs default functions and
        conditions, initialises the callbacks, builds the trainer unless
        ``lazy_built`` is set, and then enters ``_train_loop()``.

        NOTE(review): the previous documentation stated that, after training,
        the model and validation results are saved to local_worker_path and
        s3_path; that saving is not performed directly in this method --
        presumably it happens inside the train loop or callbacks (confirm).
        """
        log_level = General.logger.level
        worker_log_name = "log_worker_{}.txt".format(self.worker_id)
        init_log(level=log_level, log_file=worker_log_name, log_path=self.local_log_path)

        self._set_default_funcs()
        self._set_condition()
        self._init_callbacks()
        self.callbacks.init_trainer()

        # When lazy_built is truthy the build step is skipped here.
        if self.lazy_built:
            pass
        else:
            self.build()
        self._train_loop()
示例#5
0
def _init_env():
    """Check the interpreter version and initialise logging and cluster args.

    Exits the process via ``sys.exit`` when running on Python older than 3.6.
    Otherwise initialises logging with the task's local log path, stores the
    result of ``init_cluster_args()`` in ``General.env`` and prints the task id.
    """
    minimum_version = (3, 6)
    if sys.version_info < minimum_version:
        sys.exit('Sorry, Python < 3.6 is not supported.')

    log_level = General.logger.level
    log_dir = TaskOps().local_log_path
    init_log(level=log_level, log_path=log_dir)

    General.env = init_cluster_args()
    _print_task_id()