import os
import shutil

import torch
import torch.backends.cudnn as cudnn

# Logger and the CKP_COUNTED / CKP_LATEST / CKP_BEST filename templates are
# project-level helpers defined elsewhere in this repository.


class Predictor:
    def __init__(self, data_dir, ckp_path, save_lr=False,
                 list_dir='', out_dir='./', log_dir=''):
        super(Predictor, self).__init__()
        self.data_dir = data_dir
        self.list_dir = list_dir
        self.out_dir = out_dir
        self.checkpoint = ckp_path
        self.save_lr = save_lr
        self.logger = Logger(scrn=True, log_dir=log_dir, phase='test')
        self.model = None  # subclasses are expected to set the model

    def test_epoch(self):
        raise NotImplementedError

    def test(self):
        if self.checkpoint:
            if self._resume_from_checkpoint():
                self.model.cuda()
                self.model.eval()
                self.test_epoch()
        else:
            self.logger.warning("no checkpoint assigned!")

    def _resume_from_checkpoint(self):
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> no checkpoint found at '{}'".format(
                self.checkpoint))
            return False

        self.logger.show("=> loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint)

        state_dict = self.model.state_dict()
        ckp_dict = checkpoint.get('state_dict', checkpoint)

        try:
            # Merge the checkpoint into the model's own state dict so that
            # parameters absent from the checkpoint keep their initial values,
            # then load the merged dict.
            state_dict.update(ckp_dict)
            self.model.load_state_dict(state_dict)
        except (KeyError, RuntimeError) as e:
            # load_state_dict raises RuntimeError on mismatched keys or shapes
            self.logger.error("=> mismatched checkpoint for test")
            self.logger.error(e)
            return False
        else:
            self.epoch = checkpoint.get('epoch', 0)
            self.logger.show("=> loaded checkpoint '{}'".format(self.checkpoint))
            return True
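
# A minimal sketch of how Predictor is intended to be subclassed. The model
# and data loader below are hypothetical stand-ins, not part of this
# repository; test() only requires that self.model is set before it runs.
class ExamplePredictor(Predictor):
    def __init__(self, model, data_loader, **kwargs):
        super(ExamplePredictor, self).__init__(**kwargs)
        self.model = model
        self.data_loader = data_loader

    def test_epoch(self):
        # Run inference over the test set without tracking gradients
        with torch.no_grad():
            for inputs in self.data_loader:
                outputs = self.model(inputs.cuda())
                # ... post-process or save outputs to self.out_dir here ...

# Hypothetical usage:
#   ExamplePredictor(model, loader, data_dir='...', ckp_path='...').test()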
class Trainer:
    def __init__(self, settings):
        super(Trainer, self).__init__()
        self.settings = settings
        self.phase = settings.cmd
        self.batch_size = settings.batch_size
        self.data_dir = settings.data_dir
        self.list_dir = settings.list_dir
        self.checkpoint = settings.resume
        self.load_checkpoint = (len(self.checkpoint) > 0)
        self.num_epochs = settings.num_epochs
        self.lr = float(settings.lr)
        self.save = settings.save_on or settings.out_dir
        self.from_pause = self.settings.continu
        self.path_ctrl = settings.global_path
        self.path = self.path_ctrl.get_path

        log_dir = '' if settings.log_off else self.path_ctrl.get_dir('log')
        self.logger = Logger(scrn=True, log_dir=log_dir, phase=self.phase)

        # Dump the full settings to the log for reproducibility
        for k, v in sorted(settings.__dict__.items()):
            self.logger.show("{}: {}".format(k, v))

        self.start_epoch = 0
        self._init_max_acc = 0.0
        # Subclasses are expected to set the model, criterion, and optimizer
        self.model = None
        self.criterion = None

    def train_epoch(self):
        raise NotImplementedError

    def validate_epoch(self, epoch, store):
        raise NotImplementedError

    def train(self):
        cudnn.benchmark = True

        if self.load_checkpoint:
            self._resume_from_checkpoint()

        max_acc = self._init_max_acc
        best_epoch = self.get_ckp_epoch()

        self.model.cuda()
        self.criterion.cuda()

        end_epoch = self.num_epochs if self.from_pause \
            else self.start_epoch + self.num_epochs

        for epoch in range(self.start_epoch, end_epoch):
            lr = self._adjust_learning_rate(epoch)

            self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))

            # Train for one epoch
            self.train_epoch()

            # Evaluate the model on the validation set
            self.logger.show_nl("Validate")
            acc = self.validate_epoch(epoch=epoch, store=self.save)

            is_best = acc > max_acc
            if is_best:
                max_acc = acc
                best_epoch = epoch
            self.logger.show_nl(
                "Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
                    acc, epoch, max_acc, best_epoch))

            # The saved epoch field points to the next epoch to run
            self._save_checkpoint(self.model.state_dict(), max_acc,
                                  epoch + 1, is_best)

    def validate(self):
        if self.checkpoint:
            if self._resume_from_checkpoint():
                self.model.cuda()
                self.criterion.cuda()
                self.validate_epoch(self.get_ckp_epoch(), self.save)
        else:
            self.logger.warning("no checkpoint assigned!")

    def _load_pretrained(self):
        raise NotImplementedError

    def _adjust_learning_rate(self, epoch):
        # Note that this does not take effect for separate learning rates
        start_epoch = 0 if self.from_pause else self.start_epoch
        if self.settings.lr_mode == 'step':
            lr = self.lr * (0.5 ** ((epoch - start_epoch) // self.settings.step))
        elif self.settings.lr_mode == 'poly':
            lr = self.lr * (1 - (epoch - start_epoch) /
                            (self.num_epochs - start_epoch)) ** 1.1
        elif self.settings.lr_mode == 'const':
            lr = self.lr
        else:
            raise ValueError('unknown lr mode {}'.format(self.settings.lr_mode))

        if lr == self.lr:
            return self.lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def _resume_from_checkpoint(self):
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> no checkpoint found at '{}'".format(
                self.checkpoint))
            return False

        self.logger.show("=> loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint)

        state_dict = self.model.state_dict()
        ckp_dict = checkpoint.get('state_dict', checkpoint)
        # Keep only the parameters whose names and shapes match the model
        update_dict = {
            k: v for k, v in ckp_dict.items()
            if k in state_dict and state_dict[k].shape == v.shape
        }

        num_to_update = len(update_dict)
        if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
            if self.phase == 'val':
                self.logger.error("=> mismatched checkpoint for validation")
                return False
            self.logger.warning(
                "warning: trying to load a mismatched checkpoint")
            if num_to_update == 0:
                self.logger.error("=> no parameter is to be loaded")
                return False
            else:
                self.logger.warning(
                    "=> {} params are to be loaded".format(num_to_update))
        elif (not self.settings.anew) or (self.phase != 'train'):
            # Note that in the non-anew mode, it is not guaranteed that the
            # max_acc field corresponds to the loaded checkpoint.
            self.start_epoch = checkpoint.get('epoch', self.start_epoch)
            self._init_max_acc = checkpoint.get('max_acc', self._init_max_acc)

        state_dict.update(update_dict)
        self.model.load_state_dict(state_dict)

        self.logger.show(
            "=> loaded checkpoint '{}' (epoch {}, max_acc {:.4f})".format(
                self.checkpoint, self.get_ckp_epoch(), self._init_max_acc))
        return True

    def _save_checkpoint(self, state_dict, max_acc, epoch, is_best):
        state = {
            'epoch': epoch,
            'state_dict': state_dict,
            'max_acc': max_acc
        }
        # Save history
        history_path = self.path(
            'weight', CKP_COUNTED.format(e=epoch, s=self.scale),
            underline=True)
        if (epoch - self.start_epoch) % self.settings.trace_freq == 0:
            torch.save(state, history_path)
        # Save latest
        latest_path = self.path(
            'weight', CKP_LATEST.format(s=self.scale),
            underline=True)
        torch.save(state, latest_path)
        if is_best:
            shutil.copyfile(
                latest_path,
                self.path('weight', CKP_BEST.format(s=self.scale),
                          underline=True))

    def get_ckp_epoch(self):
        # Get the current epoch of the checkpoint
        # For a mismatched checkpoint or no checkpoint, this is 0
        return max(self.start_epoch - 1, 0)