class Predictor:
    """Base class for evaluating a trained model on a test set.

    Subclasses implement :meth:`test_epoch`; :meth:`test` restores the
    checkpoint, moves the model to GPU and runs one evaluation pass.
    """

    def __init__(self, data_dir, ckp_path, save_lr=False, list_dir='',
                 out_dir='./', log_dir=''):
        super(Predictor, self).__init__()
        self.data_dir = data_dir
        self.list_dir = list_dir
        self.out_dir = out_dir
        self.checkpoint = ckp_path
        self.save_lr = save_lr
        self.logger = Logger(scrn=True, log_dir=log_dir, phase='test')
        # The concrete model is supplied by the subclass before test().
        self.model = None

    def test_epoch(self):
        # One full pass over the test data; implemented by subclasses.
        raise NotImplementedError

    def test(self):
        """Restore the checkpoint, then evaluate; warn when none was given."""
        if self.checkpoint:
            if self._resume_from_checkpoint():
                self.model.cuda()
                self.model.eval()
                self.test_epoch()
        else:
            self.logger.warning("no checkpoint assigned!")

    def _resume_from_checkpoint(self):
        """Load ``self.checkpoint`` into ``self.model``.

        Returns:
            bool: True on success; False when the file is missing or the
            stored parameters do not fit the model.
        """
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> no checkpoint found at '{}'".format(
                self.checkpoint))
            return False
        self.logger.show("=> loading checkpoint '{}'".format(self.checkpoint))
        # map_location keeps this working on CPU-only hosts; the model is
        # moved to GPU afterwards in test().
        checkpoint = torch.load(self.checkpoint, map_location='cpu')
        state_dict = self.model.state_dict()
        # Accept either a raw state dict or one wrapped with metadata.
        ckp_dict = checkpoint.get('state_dict', checkpoint)
        try:
            # Merge into the model's own state dict so parameters absent from
            # the checkpoint keep their current values, then load the MERGED
            # dict (the original mistakenly loaded the bare ckp_dict).
            state_dict.update(ckp_dict)
            self.model.load_state_dict(state_dict)
        except (KeyError, RuntimeError) as e:
            # load_state_dict signals key/shape mismatches via RuntimeError;
            # the original KeyError-only handler let those escape.
            self.logger.error("=> mismatched checkpoint for test")
            self.logger.error(e)
            return False
        else:
            self.epoch = checkpoint.get('epoch', 0)
            self.logger.show("=> loaded checkpoint '{}'".format(self.checkpoint))
            return True
class Predictor:
    """Run a trained model over inputs supplied in one of several modes.

    Modes:
        'dataloader' -- iterate a DataLoader yielding (name, 4-d tensor)
                        batches (must be passed as ``dataloader=...``, see
                        the note in ``__init__``)
        'folder'     -- every file inside a directory
        'list'       -- an explicit list of file paths
        'file'       -- a single file path
        'data'       -- an already-loaded 4-d tensor
    """

    modes = ['dataloader', 'folder', 'list', 'file', 'data']

    def __init__(self, model=None, mode='folder', save_dir=None, scrn=True,
                 log_dir=None, cuda_off=False):
        self.save_dir = save_dir
        self.output = None  # last raw model output, cached by predict_base
        # Prefer GPU unless explicitly disabled or unavailable.
        if not cuda_off and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        assert model is not None, "The model must be assigned"
        self.model = self._model_init(model)
        if mode not in Predictor.modes:
            raise NotImplementedError
        self.logger = Logger(scrn=scrn, log_dir=log_dir, phase='predict')
        if mode == 'dataloader':
            # NOTE: dataloader is pre-bound as a keyword, so callers must
            # invoke the predictor as ``predictor(dataloader=dl)``.
            self._predict = partial(self._predict_dataloader,
                                    dataloader=None, save_dir=save_dir)
        elif mode == 'folder':
            # self.suffix = ['.jpg', '.png', '.bmp', '.gif', '.npy']  # supported image formats
            self._predict = partial(self._predict_folder, save_dir=save_dir)
        elif mode == 'list':
            self._predict = partial(self._predict_list, save_dir=save_dir)
        elif mode == 'file':
            self._predict = partial(self._predict_file, save_dir=save_dir)
        elif mode == 'data':
            # _predict_data now accepts save_dir; the original binding raised
            # TypeError at call time because the method had no such parameter.
            self._predict = partial(self._predict_data, save_dir=save_dir)
        else:
            # Unreachable: the membership check above already raised.
            raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self._predict(*args, **kwargs)

    def _model_init(self, model):
        # Move to the chosen device and freeze in eval mode.
        model.to(self.device)
        model.eval()
        return model

    def _load_data(self, path):
        return io.imread(path)

    def _to_tensor(self, arr):
        return to_tensor(arr)

    def _to_array(self, tensor):
        return to_array(tensor)

    def _normalize(self, tensor):
        return normalize(tensor)

    def _np2tensor(self, arr):
        # ndarray -> normalized torch tensor.
        nor_tensor = self._normalize(self._to_tensor(arr))
        assert isinstance(nor_tensor, torch.Tensor)
        return nor_tensor

    def _save_data_NTIRE2020(self, data, path):
        """Save a spectral cube in the NTIRE-2020 ``.mat`` layout."""
        s_dir = os.path.dirname(path)
        if s_dir:
            # makedirs+exist_ok tolerates existing dirs and missing parents,
            # where the original bare os.mkdir raised.
            os.makedirs(s_dir, exist_ok=True)
        path = path.replace('_clean.png', '.mat').replace('_RealWorld.png', '.mat')
        if isinstance(data, torch.Tensor):
            data = self._to_array(data).squeeze()
        content = {}
        content['cube'] = data
        # 31 spectral bands, 400-700nm in 10nm steps.
        content['bands'] = np.array([[400, 410, 420, 430, 440, 450, 460, 470,
                                      480, 490, 500, 510, 520, 530, 540, 550,
                                      560, 570, 580, 590, 600, 610, 620, 630,
                                      640, 650, 660, 670, 680, 690, 700]])
        # content['norm_factor'] =
        hdf5.write(data=content, filename=path, store_python_metadata=True,
                   matlab_compatible=True)

    def _save_data(self, data, path):
        """Save a tensor as an image file, creating the directory if needed."""
        s_dir = os.path.dirname(path)
        if s_dir:
            os.makedirs(s_dir, exist_ok=True)
        torchvision.utils.save_image(data, path)

    def predict_base(self, model, data, path=None):
        """Forward ``data`` through ``model``; optionally save the output.

        Returns:
            (output, seconds): model output and wall-clock inference time.
        """
        start = time.time()
        with torch.no_grad():
            output = model(data)
        # Only synchronize when CUDA exists; the unconditional call crashed
        # on CPU-only installations.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        su_time = time.time() - start
        if path:
            self._save_data_NTIRE2020(output, path)
        self.output = output
        return output, su_time

    def _predict_dataloader(self, dataloader, save_dir=None):
        assert dataloader is not None, \
            "In 'dataloader' mode the input must be a valid dataloader!"
        consume_time = AverageMeter()
        pb = tqdm(dataloader)
        for idx, (name, data) in enumerate(pb):
            assert isinstance(data, torch.Tensor) and data.dim() == 4, \
                "input data must be 4-dimention tensor"
            data = data.to(self.device)  # 4-d tensor
            save_path = os.path.join(save_dir, name) if save_dir else None
            _, su_time = self.predict_base(self.model, data, path=save_path)
            consume_time.update(su_time, n=1)
            # logger (NOTE(review): idx counts batches but the total is
            # len(dataset) -- confirm intended)
            description = (
                "[{}/{}] speed: {time.val:.4f}s({time.avg:.4f}s)".format(
                    idx + 1, len(dataloader.dataset), time=consume_time))
            pb.set_description(description)
            self.logger.dump(description)

    def _predict_folder(self, folder, save_dir=None):
        assert folder is not None and os.path.isdir(folder), \
            "In 'folder' mode the input must be a valid path of a folder!"
        consume_time = AverageMeter()
        file_list = glob.glob(os.path.join(folder, '*'))
        assert not len(file_list) == 0, "The input folder is empty"
        pb = tqdm(file_list)  # progress bar
        for idx, file in enumerate(pb):
            img = self._load_data(file)
            name = os.path.basename(file)
            img = self._np2tensor(img).unsqueeze(0).to(self.device)
            save_path = os.path.join(save_dir, name) if save_dir else None
            _, su_time = self.predict_base(model=self.model, data=img,
                                           path=save_path)
            consume_time.update(su_time)
            # logger
            description = (
                "[{}/{}] speed: {time.val:.4f}s({time.avg:.4f}s)".format(
                    idx + 1, len(file_list), time=consume_time))
            pb.set_description(description)
            self.logger.dump(description)

    def _predict_list(self, file_list, save_dir=None):
        assert isinstance(file_list, list), \
            "In 'list' mode the input must be a valid file_path list!"
        consume_time = AverageMeter()
        assert not len(file_list) == 0, "The input file list is empty!"
        pb = tqdm(file_list)  # progress bar
        for idx, path in enumerate(pb):
            data = self._load_data(path)
            name = os.path.basename(path)
            data = self._np2tensor(data).unsqueeze(0).to(self.device)
            path = os.path.join(save_dir, name) if save_dir else None
            _, su_time = self.predict_base(model=self.model, data=data,
                                           path=path)
            consume_time.update(su_time, n=1)
            # logger
            description = (
                "[{}/{}] speed: {time.val:.4f}s({time.avg:.4f}s)".format(
                    idx + 1, len(file_list), time=consume_time))
            pb.set_description(description)
            self.logger.dump(description)

    def _predict_file(self, file_path, save_dir=None):
        assert isinstance(file_path, str) and os.path.isfile(file_path), \
            "In 'file' mode the input must a valid path of a file!"
        consume_time = AverageMeter()
        data = self._load_data(file_path)
        name = os.path.basename(file_path)
        data = self._np2tensor(data).unsqueeze(0).to(self.device)
        path = os.path.join(save_dir, name) if save_dir else None
        _, su_time = self.predict_base(model=self.model, data=data, path=path)
        consume_time.update(su_time)
        # logger
        description = ("file: {} speed: {time.val:.4f}s".format(
            name, time=consume_time))
        self.logger.show(description)

    def _predict_data(self, data, save_dir=None):
        """Predict on an already-loaded 4-d tensor.

        ``save_dir`` is accepted for signature compatibility with the other
        modes (the ``__init__`` partial binds it) but is unused: there is no
        file name from which to derive an output path.

        :return: tensor
        """
        assert isinstance(data, torch.Tensor) and data.dim() == 4, \
            "In 'data' mode the input must be a 4-d tensor"
        consume_time = AverageMeter()
        output, su_time = self.predict_base(model=self.model, data=data)
        consume_time.update(su_time)
        # logger
        description = ("speed: {time.val:.4f}s".format(time=consume_time))
        self.logger.dump(description)
        return output
class Trainer:
    """Generic training driver: epoch loop, LR scheduling, checkpointing.

    Subclasses implement ``train_epoch``/``validate_epoch`` and are expected
    to provide ``self.model``, ``self.criterion`` and ``self.optimizer``.
    """

    def __init__(self, settings):
        super(Trainer, self).__init__()
        self.settings = settings
        self.phase = settings.cmd
        self.batch_size = settings.batch_size
        self.data_dir = settings.data_dir
        self.list_dir = settings.list_dir
        self.checkpoint = settings.resume
        # A non-empty resume path means training should restore a checkpoint.
        self.load_checkpoint = len(self.checkpoint) > 0
        self.num_epochs = settings.num_epochs
        self.lr = float(settings.lr)
        self.save = settings.save_on or settings.out_dir
        self.from_pause = self.settings.continu
        self.path_ctrl = settings.global_path
        self.path = self.path_ctrl.get_path
        log_dir = '' if settings.log_off else self.path_ctrl.get_dir('log')
        self.logger = Logger(scrn=True, log_dir=log_dir, phase=self.phase)
        # Echo every setting so the run configuration is recorded in the log.
        for key, val in sorted(settings.__dict__.items()):
            self.logger.show("{}: {}".format(key, val))
        self.start_epoch = 0
        self._init_max_acc = 0.0
        self.model = None
        self.criterion = None

    def train_epoch(self):
        # One training pass; implemented by subclasses.
        raise NotImplementedError

    def validate_epoch(self, epoch, store):
        # One validation pass returning an accuracy; implemented by subclasses.
        raise NotImplementedError

    def train(self):
        """Run the full training loop with per-epoch validation and saving."""
        cudnn.benchmark = True
        if self.load_checkpoint:
            self._resume_from_checkpoint()
        best_acc = self._init_max_acc
        best_epoch = self.get_ckp_epoch()
        self.model.cuda()
        self.criterion.cuda()
        if self.from_pause:
            last_epoch = self.num_epochs
        else:
            last_epoch = self.start_epoch + self.num_epochs
        for epoch in range(self.start_epoch, last_epoch):
            cur_lr = self._adjust_learning_rate(epoch)
            self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, cur_lr))
            # Train for one epoch.
            self.train_epoch()
            # Evaluate on the validation set.
            self.logger.show_nl("Validate")
            acc = self.validate_epoch(epoch=epoch, store=self.save)
            is_best = acc > best_acc
            if is_best:
                best_acc = acc
                best_epoch = epoch
            self.logger.show_nl(
                "Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
                    acc, epoch, best_acc, best_epoch))
            # The stored epoch number is the one to run next on resume.
            self._save_checkpoint(self.model.state_dict(), best_acc,
                                  epoch + 1, is_best)

    def validate(self):
        """Evaluate once using the resumed checkpoint (no training)."""
        if not self.checkpoint:
            self.logger.warning("no checkpoint assigned!")
            return
        if self._resume_from_checkpoint():
            self.model.cuda()
            self.criterion.cuda()
            self.validate_epoch(self.get_ckp_epoch(), self.save)

    def _load_pretrained(self):
        raise NotImplementedError

    def _adjust_learning_rate(self, epoch):
        """Compute and apply the LR for ``epoch``; return the value in effect.

        Note: this does not take effect for separate (per-group) learning
        rates.
        """
        base = 0 if self.from_pause else self.start_epoch
        mode = self.settings.lr_mode
        if mode == 'step':
            new_lr = self.lr * 0.5 ** ((epoch - base) // self.settings.step)
        elif mode == 'poly':
            frac = (epoch - base) / (self.num_epochs - base)
            new_lr = self.lr * (1 - frac) ** 1.1
        elif mode == 'const':
            new_lr = self.lr
        else:
            raise ValueError('unknown lr mode {}'.format(mode))
        # Rate unchanged from the base value: leave the optimizer untouched.
        if new_lr == self.lr:
            return self.lr
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
        return new_lr

    def _resume_from_checkpoint(self):
        """Restore weights (and, when appropriate, epoch/accuracy) from disk.

        Returns True on success, False when the file is missing or the
        checkpoint cannot be used for the current phase.
        """
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> no checkpoint found at '{}'".format(
                self.checkpoint))
            return False
        self.logger.show("=> loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint)
        state_dict = self.model.state_dict()
        # Accept either a raw state dict or one wrapped with metadata.
        ckp_dict = checkpoint.get('state_dict', checkpoint)
        # Keep only entries whose name and shape both match the current model.
        loadable = {
            name: tensor
            for name, tensor in ckp_dict.items()
            if name in state_dict and state_dict[name].shape == tensor.shape
        }
        mismatched = (len(loadable) < len(state_dict)
                      or len(state_dict) < len(ckp_dict))
        if mismatched:
            if self.phase == 'val':
                self.logger.error("=> mismatched checkpoint for validation")
                return False
            self.logger.warning(
                "warning: trying to load an mismatched checkpoint")
            if not loadable:
                self.logger.error("=> no parameter is to be loaded")
                return False
            self.logger.warning(
                "=> {} params are to be loaded".format(len(loadable)))
        elif (not self.settings.anew) or (self.phase != 'train'):
            # Note in the non-anew mode, it is not guaranteed that the
            # contained field max_acc be the corresponding one of the loaded
            # checkpoint.
            self.start_epoch = checkpoint.get('epoch', self.start_epoch)
            self._init_max_acc = checkpoint.get('max_acc', self._init_max_acc)
        state_dict.update(loadable)
        self.model.load_state_dict(state_dict)
        self.logger.show(
            "=> loaded checkpoint '{}' (epoch {}, max_acc {:.4f})".format(
                self.checkpoint, self.get_ckp_epoch(), self._init_max_acc))
        return True

    def _save_checkpoint(self, state_dict, max_acc, epoch, is_best):
        """Write periodic/latest/best checkpoints; ``epoch`` is the next one.

        NOTE(review): ``self.scale`` is not set in ``__init__`` here — it is
        presumably assigned by a subclass; confirm before standalone use.
        """
        state = {'epoch': epoch, 'state_dict': state_dict, 'max_acc': max_acc}
        # Periodic history snapshot, every trace_freq epochs.
        history_path = self.path('weight',
                                 CKP_COUNTED.format(e=epoch, s=self.scale),
                                 underline=True)
        if (epoch - self.start_epoch) % self.settings.trace_freq == 0:
            torch.save(state, history_path)
        # Always refresh the latest checkpoint.
        latest_path = self.path('weight',
                                CKP_LATEST.format(s=self.scale),
                                underline=True)
        torch.save(state, latest_path)
        if is_best:
            best_path = self.path('weight',
                                  CKP_BEST.format(s=self.scale),
                                  underline=True)
            shutil.copyfile(latest_path, best_path)

    def get_ckp_epoch(self):
        # Current epoch of the checkpoint; 0 for a mismatched or absent one.
        return max(self.start_epoch - 1, 0)