def train_process(self):
    """Run the whole train-and-validate process for the fully-train CycleSR worker.

    Steps, in order:

    1. Initialise logging, the report, and (when enabled in ``self.cfg``)
       the CUDA and Horovod settings.
    2. Build the model and the ``'train'`` / ``'test'`` dataloaders.
    3. For every epoch from ``cfg.epoch_count`` through
       ``cfg.n_epoch + cfg.n_epoch_decay`` (inclusive): update the learning
       rates, train one epoch and, every ``cfg.eval_epoch`` epochs, evaluate
       the GAN samples and the validation set, broadcast the PSNR and write
       checkpoints.
    4. Log the total elapsed time and the time spent purely in training.

    The method returns ``None``; its results are the side effects performed
    by the helpers it calls (checkpoints, broadcasts, summary events, logs).
    """
    init_log(level=General.logger.level,
             log_file="log_worker_{}.txt".format(self.trainer.worker_id),
             log_path=self.trainer.local_log_path)
    self._init_report()
    if self.cfg.cuda:
        self.trainer._init_cuda_setting()
    self.model = self._init_model()
    if self.cfg.distributed:
        self._horovod_init_optimizer()
        self._init_horovod_setting()
    self.train_data = self._init_dataloader('train')
    self.valid_data = self._init_dataloader('test')
    train_dataloader = Adapter(self.train_data).loader
    valid_dataloader = Adapter(self.valid_data).loader

    writer = SummaryWriter(self.worker_path)
    start_time = time.time()
    # Accumulated unrounded; rounding every epoch would drop sub-second
    # epochs entirely and let rounding error pile up over long runs.
    train_seconds = 0.0
    best_psnr = -np.inf
    best_epoch = 0
    last_epoch = self.cfg.n_epoch + self.cfg.n_epoch_decay
    try:
        logging.info("==> Start training")
        val_gan_imgs = self.getValImg(self.train_data, val_num=5)
        for epoch in range(self.cfg.epoch_count, last_epoch + 1):
            self.model.update_learning_rate(
                epoch,
                self.cfg.model_desc.custom.cyc_lr,
                self.cfg.model_desc.custom.SR_lr,
                self.cfg.n_epoch,
                self.cfg.n_epoch_decay)
            start_train_time = time.time()
            self._train(train_dataloader, writer, epoch, self.model,
                        print_freq=self.cfg.print_freq)
            train_seconds += time.time() - start_train_time

            # Validation runs only on every ``eval_epoch``-th epoch.
            if epoch % self.cfg.eval_epoch != 0:
                continue
            logging.info("==> Validng")
            self._evalGAN(self.model, val_gan_imgs, epoch, writer)
            val_ave_psnr = self._valid(self.model, valid_dataloader, epoch,
                                       self.cfg.eval_epoch, writer,
                                       self.cfg.val_ps_offset)
            if val_ave_psnr is not None:
                logging.info("==> Current ave psnr is {:.3f}".format(val_ave_psnr))
                improved = val_ave_psnr > best_psnr
                if improved:
                    best_psnr = val_ave_psnr
                    best_epoch = epoch
                logging.info(
                    "==> Best PSNR on val dataset {:.3f}, achieved at epoch {}".format(best_psnr, best_epoch))
                # Only overwrite the "best" checkpoint when this epoch really
                # is the best so far; otherwise a worse model would replace it.
                if improved:
                    self._save_checkpoint(epoch, best=True)
                self._broadcast(epoch, {"psnr": val_ave_psnr})
            # NOTE(review): the regular checkpoint is written on evaluation
            # epochs only - confirm this matches the intended save cadence.
            model_name = 'epoch' + str(epoch)
            logging.info("Saving checkpoints to {}".format(model_name))
            self._save_checkpoint(epoch)
    finally:
        # Always flush and release the summary writer, even if an epoch raised.
        writer.close()

    elapsed_text = str(datetime.timedelta(seconds=round(time.time() - start_time)))
    train_time_text = str(datetime.timedelta(seconds=round(train_seconds)))
    logging.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(
        elapsed_text, train_time_text))
def init_trainer(self):
    """Initialise everything the trainer needs before its train op can run.

    Configures the per-worker log file, then runs the setup hooks in a fixed
    order: default functions, conditions, callbacks, the callbacks' own
    ``init_trainer`` hook and finally ``init_train_op``.
    """
    log_level = General.logger.level
    log_name = "log_worker_{}.txt".format(self.worker_id)
    log_dir = self.local_log_path
    init_log(level=log_level, log_file=log_name, log_path=log_dir)

    # The hooks are called one after another (not pre-collected) so each one
    # sees whatever state the previous hook installed on ``self``.
    self._set_default_funcs()
    self._set_condition()
    self._init_callbacks()

    self.callbacks.init_trainer()
    self.init_train_op()
def train_process(self):
    """Evaluate the model on the host and broadcast the result.

    Configures the host-evaluator log file, loads the model, builds the
    ``'test'`` dataloader (stored on ``self.valid_loader``), runs
    ``self.valid`` on it and passes the returned performance to
    ``self._broadcast``.
    """
    log_level = General.logger.level
    log_name = "host_evaluator_{}.log".format(self.worker_id)
    log_dir = self.local_log_path
    init_log(level=log_level, log_file=log_name, log_path=log_dir)

    logging.info("start evaluate process")
    self.load_model()
    self.valid_loader = self._init_dataloader(mode='test')
    # Validate on the freshly built loader and publish the outcome directly.
    self._broadcast(self.valid(self.valid_loader))
    logging.info("the model (id {}) is evaluated on the host".format(self.worker_id))
def train_process(self):
    """Run the whole train process of the TrainWorker specified in config.

    Configures the per-worker log file, runs the setup hooks (default
    functions, conditions, callbacks), builds the trainer unless it is marked
    as lazily built, and then enters the training loop.

    Per the original documentation, the model and validation results are
    saved to local_worker_path and s3_path after training; that saving is not
    performed in this method's own body.
    """
    log_level = General.logger.level
    log_name = "log_worker_{}.txt".format(self.worker_id)
    log_dir = self.local_log_path
    init_log(level=log_level, log_file=log_name, log_path=log_dir)

    self._set_default_funcs()
    self._set_condition()
    self._init_callbacks()
    self.callbacks.init_trainer()

    # A lazily built trainer skips the build step here.
    needs_build = not self.lazy_built
    if needs_build:
        self.build()

    self._train_loop()
def _init_env():
    """Check the interpreter version, then set up logging and cluster state.

    Exits the process with an error message when running on Python older
    than 3.6. Otherwise initialises logging under the task's local log path,
    stores the cluster arguments on ``General.env`` and prints the task id.
    """
    minimum_python = (3, 6)
    if sys.version_info < minimum_python:
        sys.exit('Sorry, Python < 3.6 is not supported.')

    init_log(
        level=General.logger.level,
        log_path=TaskOps().local_log_path,
    )
    cluster_env = init_cluster_args()
    General.env = cluster_env
    _print_task_id()