def load(self, load_optim: bool = True):
    """Restore the newest checkpoint found in ``<model_dir>/<save_name>/``.

    Every file in that directory is a candidate; the one with the latest
    ``os.path.getctime`` is loaded. Model weights are always restored (into
    ``self.model.module`` when the model is wrapped in ``nn.DataParallel``),
    ``self.seed`` is restored when the checkpoint carries a ``'seed'`` entry,
    and ``self.step`` is set from the checkpoint. If the directory holds no
    files, a message is logged and nothing else happens.

    Args:
        load_optim: also restore the optimizer state from the checkpoint.
    """
    save_path = os.path.join(self.model_dir, self.save_name)

    check_files = glob.glob(os.path.join(save_path, '*'))
    if not check_files:
        log('No any checkpoint in {}. Loading network skipped.'.format(save_path))
        return

    latest_file = max(check_files, key=os.path.getctime)
    # Deserialize onto CPU. load_state_dict() copies values into the existing
    # parameters (and Optimizer.load_state_dict casts its state to the
    # parameters' device), so the final placement is unchanged, but loading no
    # longer fails when the checkpoint was written on a GPU absent here.
    state_dict = torch.load(latest_file, map_location=torch.device('cpu'))
    if 'seed' in state_dict:
        self.seed = state_dict['seed']

    # get_loadable_checkpoint's output is loaded into the unwrapped module.
    if isinstance(self.model, nn.DataParallel):
        target = self.model.module
    else:
        target = self.model
    target.load_state_dict(get_loadable_checkpoint(state_dict['model']))

    if load_optim:
        self.optimizer.load_state_dict(state_dict['optim'])
    if self.scheduler is not None:
        # Checkpoints written while no scheduler was configured have no
        # 'scheduler' entry; skip (and say so) instead of raising KeyError.
        if 'scheduler' in state_dict:
            self.scheduler.load_state_dict(state_dict['scheduler'])
        else:
            log('No scheduler state in \'{}\'. Scheduler state not restored.'.format(latest_file))
    self.step = state_dict['step']
    log('checkpoint \'{}\' is loaded. previous step={}'.format(latest_file, self.step))
def __load_model(model_name: str, pretrained_path: str) -> torch.nn.Module:
    """Build ``model_name`` on the current CUDA device and load pretrained weights.

    Args:
        model_name: name passed to ``build_model``.
        pretrained_path: checkpoint file whose ``'model'`` entry holds the weights.

    Returns:
        The CUDA model with the weights loaded, switched to eval mode.
    """
    # NOTE(review): this SOURCE also contains a CPU-only ``__load_model``; if both
    # live in the same module, the later definition wins -- confirm which one
    # callers expect.
    print('Load model ...')
    model = build_model(model_name).cuda()
    # Deserialize on CPU; load_state_dict copies the values onto the CUDA
    # parameters. Without map_location, tensors saved on cuda:N are restored to
    # cuda:N, which fails when that device does not exist on this machine.
    chk = torch.load(pretrained_path, map_location=torch.device('cpu'))['model']
    model.load_state_dict(get_loadable_checkpoint(chk))
    model.eval()
    return model
def save(self, step: int):
    """Write a training checkpoint for ``step``, plus a best-model copy when needed.

    The payload holds the model weights (passed through
    ``get_loadable_checkpoint``), the optimizer state, the step (stored under
    both ``'step'`` and ``'pretrained_step'``), the seed, and the scheduler
    state when a scheduler is configured. It is written to
    ``<model_dir>/<save_name>/step_XXXXXX.chkpt``. When ``self.best_valid_loss``
    differs from ``self.cur_best_valid_loss``, the same payload is also written
    to ``<model_dir>/<save_name>.best.chkpt`` and the current best is updated.
    """
    checkpoint = {
        'step': step,
        'model': get_loadable_checkpoint(self.model.state_dict()),
        'optim': self.optimizer.state_dict(),
        'pretrained_step': step,
        'seed': self.seed,
    }
    if self.scheduler is not None:
        checkpoint['scheduler'] = self.scheduler.state_dict()

    name = self.save_name
    run_dir = os.path.join(self.model_dir, name)
    os.makedirs(run_dir, exist_ok=True)
    step_file = os.path.join(run_dir, 'step_{:06d}.chkpt'.format(step))
    torch.save(checkpoint, step_file)

    # The best copy lives next to the run directory, not inside it.
    if self.best_valid_loss != self.cur_best_valid_loss:
        best_file = os.path.join(self.model_dir, name + '.best.chkpt')
        torch.save(checkpoint, best_file)
        self.cur_best_valid_loss = self.best_valid_loss

    log('step %d / saved model.' % step)
def __load_model(model_name: str, pretrained_path: str) -> torch.nn.Module:
    """Build ``model_name``, load its pretrained weights, and return it in eval mode.

    The checkpoint is deserialized onto the CPU. This function does not move
    the model to any device itself.

    Args:
        model_name: name passed to ``build_model``.
        pretrained_path: checkpoint file whose ``'model'`` entry holds the weights.
    """
    print('Load model ...')
    net = build_model(model_name)
    cpu = torch.device('cpu')
    checkpoint = torch.load(pretrained_path, map_location=cpu)
    weights = get_loadable_checkpoint(checkpoint['model'])
    net.load_state_dict(weights)
    net.eval()
    return net
def load_pretrained_model(self):
    """Load model weights from the checkpoint at ``self.pretrained_trained``.

    Only the checkpoint's ``'model'`` entry is used; optimizer, scheduler and
    step are left untouched. When ``self.model`` is an ``nn.DataParallel``
    wrapper, the weights go into ``self.model.module``.

    Raises:
        AssertionError: if ``self.pretrained_trained`` is unset/empty or does
            not point to an existing path.
    """
    pretrained_path = self.pretrained_trained
    # Test truthiness first: os.path.exists(None) raises TypeError instead of
    # returning False, which would hide the intended message below.
    assert pretrained_path and os.path.exists(pretrained_path), \
        'You must define pretrained path!'
    # Deserialize on CPU; load_state_dict copies into the parameters' device.
    checkpoint = torch.load(pretrained_path, map_location=torch.device('cpu'))
    weights = get_loadable_checkpoint(checkpoint['model'])
    # A DataParallel wrapper registers the real network as `.module`, so its
    # own state-dict keys carry a 'module.' prefix. The trainer's load() puts
    # get_loadable_checkpoint output into the unwrapped module; do the same.
    if isinstance(self.model, nn.DataParallel):
        target = self.model.module
    else:
        target = self.model
    target.load_state_dict(weights)