Example #1
    def train_process(self):
        """Whole train and validate process for the fully train cyclesr."""
        # self._init_all_settings()
        init_log(level=General.logger.level,
                 log_file="log_worker_{}.txt".format(self.trainer.worker_id),
                 log_path=self.trainer.local_log_path)
        self._init_report()
        if self.cfg.cuda:
            self.trainer._init_cuda_setting()
        self.model = self._init_model()
        if self.cfg.distributed:
            self._horovod_init_optimizer()
            self._init_horovod_setting()
        self.train_data = self._init_dataloader('train')
        self.valid_data = self._init_dataloader('test')
        train_dataloader = Adapter(self.train_data).loader
        valid_dataloader = Adapter(self.valid_data).loader

        writer = SummaryWriter(self.worker_path)

        start_time = time.time()
        train_time = 0
        best_psnr = -np.inf
        best_epoch = 0
        logging.info("==> Start training")
        val_gan_imgs = self.getValImg(self.train_data, val_num=5)
        for epoch in range(self.cfg.epoch_count, self.cfg.n_epoch + self.cfg.n_epoch_decay + 1):
            self.model.update_learning_rate(
                epoch,
                self.cfg.model_desc.custom.cyc_lr,
                self.cfg.model_desc.custom.SR_lr,
                self.cfg.n_epoch,
                self.cfg.n_epoch_decay)
            start_train_time = time.time()
            self._train(train_dataloader, writer, epoch, self.model, print_freq=self.cfg.print_freq)
            train_time += round(time.time() - start_train_time)
            # validation
            if epoch % self.cfg.eval_epoch == 0:
                logging.info("==> Validng")
                self._evalGAN(self.model, val_gan_imgs, epoch, writer)
                val_ave_psnr = self._valid(self.model, valid_dataloader, epoch, self.cfg.eval_epoch, writer,
                                           self.cfg.val_ps_offset)
                if val_ave_psnr is not None:
                    logging.info("==> Current ave psnr is {:.3f}".format(val_ave_psnr))
                    if val_ave_psnr > best_psnr:
                        best_psnr = val_ave_psnr
                        best_epoch = epoch
                        # keep a separate copy of the weights that achieved the best PSNR
                        self._save_checkpoint(epoch, best=True)
                    logging.info(
                        "==> Best PSNR on val dataset {:.3f}, achieved at epoch {}".format(best_psnr, best_epoch))
                    self._broadcast(epoch, {"psnr": val_ave_psnr})
                model_name = 'epoch' + str(epoch)
                logging.info("Saving checkpoints to {}".format(model_name))
                self._save_checkpoint(epoch)
        elapsed = round(time.time() - start_time)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        train_time = str(datetime.timedelta(seconds=train_time))
        logging.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
Example #2
    def _init_dataloader(self, mode, loader=None, transforms=None):
        """Init dataloader."""
        if loader is not None:
            return loader
        # Resolve the dataset class from the hyper-parameter description when one
        # is provided, otherwise fall back to the default registered dataset.
        if mode == "train" and self.hps is not None and self.hps.get(
                "dataset") is not None:
            if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
                dataset_cls = ClassFactory.get_cls(
                    ClassType.DATASET,
                    self.hps.get("dataset").get('type'))
            else:
                dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
            dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
        elif self.hps:
            if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
                dataset_cls = ClassFactory.get_cls(
                    ClassType.DATASET,
                    self.hps.get("dataset").get('type'))
                dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
            else:
                dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
                dataset = dataset_cls(mode=mode)
        else:
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
            dataset = dataset_cls(mode=mode)
        if transforms is not None:
            dataset.transforms = transforms
        # shard the training set across workers when running distributed
        if self.distributed and mode == "train":
            dataset.set_distributed(self._world_size, self._rank_id)
        # adapt the dataset to the specific backend
        dataloader = Adapter(dataset).loader
        return dataloader
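For context, a hypothetical hps fragment that would take the first branch above could look like the following; the dataset type name is an assumption about what is registered with ClassFactory, not something confirmed by this snippet.

    # Hypothetical hyper-parameter dict; "Cifar10" is only an assumed registered type.
    hps = {
        "dataset": {
            "type": "Cifar10",
            "common": {"batch_size": 64},
        }
    }
    # With mode == "train", _init_dataloader would resolve the class via
    # ClassFactory.get_cls(ClassType.DATASET, "Cifar10") and build it as
    # dataset_cls(mode="train", hps=hps["dataset"]).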
Example #3
    def _init_dataloader(self, mode, loader=None):
        """Init dataloader."""
        if loader is not None:
            return loader
        dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
        dataset = dataset_cls(mode=mode)
        dataloader = Adapter(dataset).loader
        return dataloader
Example #4
    def __init__(self):
        super(LatencyFilter, self).__init__()
        self.max_latency = self.restrict_config.latency
        if self.max_latency is not None:
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
            self.dataset = dataset_cls()
            from zeus.datasets import Adapter
            self.dataloader = Adapter(self.dataset).loader
Example #5
    def __init__(self):
        super(FlopsParamsFilter, self).__init__()
        self.flops_range = self.restrict_config.flops
        self.params_range = self.restrict_config.params
        # A scalar limit is interpreted as an upper bound, i.e. the range [0., limit].
        if self.flops_range and not isinstance(self.flops_range, list):
            self.flops_range = [0., self.flops_range]
        if self.params_range and not isinstance(self.params_range, list):
            self.params_range = [0., self.params_range]
        if self.flops_range is not None or self.params_range is not None:
            dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
            self.dataset = dataset_cls()
            from zeus.datasets import Adapter
            self.dataloader = Adapter(self.dataset).loader
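The scalar-to-range normalization above can be exercised in isolation; the helper names below are hypothetical and only mirror the pattern used by FlopsParamsFilter.

    def normalize_range(value):
        # A scalar limit becomes the range [0., limit]; lists pass through unchanged.
        if value and not isinstance(value, list):
            return [0., value]
        return value

    def in_range(x, value_range):
        return value_range[0] <= x <= value_range[1]

    flops_range = normalize_range(0.6)         # -> [0.0, 0.6]
    params_range = normalize_range([1., 10.])  # unchanged
    print(in_range(0.4, flops_range))   # True
    print(in_range(12., params_range))  # False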
Example #6
    def is_filtered(self, desc=None):
        """Filter function of latency."""
        try:
            if not self.dataloader:
                dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
                self.dataset = dataset_cls()
                from zeus.datasets import Adapter
                self.dataloader = Adapter(self.dataset).loader

            model, count_input = self.get_model_input(desc)
            # A single forward pass is enough to verify that the description
            # builds a model that can run on the sample input.
            model(count_input)
            return False
        except Exception:
            encoding = desc['backbone']['encoding']
            logging.info('Invalid encoding: {}'.format(encoding))
            return True