def train_process(self):
    """Whole training and validation process for fully training CycleSR."""
    # self._init_all_settings()
    init_log(level=General.logger.level,
             log_file="log_worker_{}.txt".format(self.trainer.worker_id),
             log_path=self.trainer.local_log_path)
    self._init_report()
    if self.cfg.cuda:
        self.trainer._init_cuda_setting()
    self.model = self._init_model()
    if self.cfg.distributed:
        self._horovod_init_optimizer()
        self._init_horovod_setting()
    self.train_data = self._init_dataloader('train')
    self.valid_data = self._init_dataloader('test')
    train_dataloader = Adapter(self.train_data).loader
    valid_dataloader = Adapter(self.valid_data).loader
    writer = SummaryWriter(self.worker_path)
    start_time = time.time()
    train_time = 0
    best_psnr = -np.inf
    best_epoch = 0
    logging.info("==> Start training")
    val_gan_imgs = self.getValImg(self.train_data, val_num=5)
    for epoch in range(self.cfg.epoch_count, self.cfg.n_epoch + self.cfg.n_epoch_decay + 1):
        self.model.update_learning_rate(
            epoch, self.cfg.model_desc.custom.cyc_lr, self.cfg.model_desc.custom.SR_lr,
            self.cfg.n_epoch, self.cfg.n_epoch_decay)
        start_train_time = time.time()
        self._train(train_dataloader, writer, epoch, self.model, print_freq=self.cfg.print_freq)
        train_time += round(time.time() - start_train_time)
        # validation
        if epoch % self.cfg.eval_epoch == 0:
            logging.info("==> Validating")
            self._evalGAN(self.model, val_gan_imgs, epoch, writer)
            val_ave_psnr = self._valid(self.model, valid_dataloader, epoch,
                                       self.cfg.eval_epoch, writer, self.cfg.val_ps_offset)
            if val_ave_psnr is not None:
                logging.info("==> Current ave psnr is {:.3f}".format(val_ave_psnr))
                if val_ave_psnr > best_psnr:
                    best_psnr = val_ave_psnr
                    best_epoch = epoch
                    logging.info("==> Best PSNR on val dataset {:.3f}, achieved at epoch {}".format(
                        best_psnr, best_epoch))
                    self._save_checkpoint(epoch, best=True)
                self._broadcast(epoch, {"psnr": val_ave_psnr})
        model_name = 'epoch' + str(epoch)
        logging.info("Saving checkpoints to {}".format(model_name))
        self._save_checkpoint(epoch)
    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    logging.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(
        elapsed, train_time))
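
# The epoch range in train_process follows the CycleGAN convention: a constant
# learning rate for the first `n_epoch` epochs, then linear decay over the next
# `n_epoch_decay` epochs. A minimal sketch of such a schedule -- an assumption
# for illustration, since the actual rule lives in `model.update_learning_rate`,
# which is defined elsewhere:
def _linear_decay_lr(base_lr, epoch, n_epoch, n_epoch_decay):
    """Return base_lr until n_epoch, then decay it linearly toward zero."""
    factor = 1.0 - max(0, epoch - n_epoch) / float(n_epoch_decay + 1)
    return base_lr * factor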
def _init_dataloader(self, mode, loader=None, transforms=None):
    """Init dataloader."""
    if loader is not None:
        return loader
    if mode == "train" and self.hps is not None and self.hps.get("dataset") is not None:
        # Training with dataset hyperparameters: honor an explicit dataset type
        # if one is given, otherwise fall back to the default registered dataset.
        if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET, self.hps.get("dataset").get('type'))
        else:
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
        dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
    elif self.hps:
        # Hyperparameters exist but either this is not training or they carry
        # no dataset section: only pass them through when a type is named.
        if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET, self.hps.get("dataset").get('type'))
            dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
        else:
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
            dataset = dataset_cls(mode=mode)
    else:
        dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
        dataset = dataset_cls(mode=mode)
    if transforms is not None:
        dataset.transforms = transforms
    if self.distributed and mode == "train":
        dataset.set_distributed(self._world_size, self._rank_id)
    # Adapt the dataset to the active backend.
    dataloader = Adapter(dataset).loader
    return dataloader
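
# For reference, a hypothetical `hps` fragment that the branching above would
# consume (names and values are illustrative, not taken from a real config):
#
#     hps = {
#         "dataset": {
#             "type": "Cifar10",   # -> ClassFactory.get_cls(ClassType.DATASET, "Cifar10")
#             "batch_size": 64,
#         }
#     }
#
# When "type" is present, the named dataset class is resolved from the registry
# and instantiated with the dataset hyperparameters; otherwise the default
# registered dataset class is used.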
def _init_dataloader(self, mode, loader=None):
    """Init dataloader."""
    if loader is not None:
        return loader
    dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
    dataset = dataset_cls(mode=mode)
    dataloader = Adapter(dataset).loader
    return dataloader
def __init__(self):
    super(LatencyFilter, self).__init__()
    self.max_latency = self.restrict_config.latency
    if self.max_latency is not None:
        dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
        self.dataset = dataset_cls()
        from zeus.datasets import Adapter
        self.dataloader = Adapter(self.dataset).loader
def __init__(self):
    super(FlopsParamsFilter, self).__init__()
    self.flops_range = self.restrict_config.flops
    self.params_range = self.restrict_config.params
    if self.flops_range and not isinstance(self.flops_range, list):
        self.flops_range = [0., self.flops_range]
    if self.params_range and not isinstance(self.params_range, list):
        self.params_range = [0., self.params_range]
    if self.flops_range is not None or self.params_range is not None:
        dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
        self.dataset = dataset_cls()
        from zeus.datasets import Adapter
        self.dataloader = Adapter(self.dataset).loader
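
# Worked example of the range promotion above (illustrative values): a scalar
# restriction such as `flops: 0.6` becomes `self.flops_range = [0., 0.6]`,
# while an explicit `flops: [0.2, 0.6]` is kept as-is, so a candidate passes
# only if its measured FLOPs fall inside the closed range. The unit is whatever
# the measuring code reports (commonly GFLOPs -- an assumption here).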
def is_filtered(self, desc=None):
    """Filter function of the latency restriction."""
    try:
        if not self.dataloader:
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
            self.dataset = dataset_cls()
            from zeus.datasets import Adapter
            self.dataloader = Adapter(self.dataset).loader
        model, count_input = self.get_model_input(desc)
        # A successful forward pass means the description is usable; keep it.
        model(count_input)
        return False
    except Exception:
        # Any failure to build or run the model marks the encoding as invalid.
        encoding = desc['backbone']['encoding']
        logging.info('Invalid encoding: {}'.format(encoding))
        return True
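
# A minimal usage sketch, assuming the filter is called with a network
# description dict (`candidate_desc` is hypothetical; in practice the search
# pipeline drives this):
#
#     flt = LatencyFilter()
#     if not flt.is_filtered(desc=candidate_desc):
#         pass  # candidate survives: its model builds and runs a forward pass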