Example #1
0
    def load(self, load_optim: bool = True):
        """Restore training state from the most recent checkpoint on disk.

        Scans ``os.path.join(self.model_dir, self.save_name)`` and picks the
        file with the newest ``os.path.getctime``. From it, this restores:

        * the model weights (``'model'``, passed through
          ``get_loadable_checkpoint``);
        * the optimizer (``'optim'``), only when ``load_optim`` is True;
        * the scheduler (``'scheduler'``), when ``self.scheduler`` is set and
          the checkpoint contains that key;
        * ``self.seed`` (``'seed'``), when present;
        * ``self.step`` (``'step'``).

        If the directory holds no files, this logs a message and returns
        without changing anything.

        Args:
            load_optim: when True, also restore the optimizer state.
        """
        save_path = os.path.join(self.model_dir, self.save_name)

        check_files = glob.glob(os.path.join(save_path, '*'))
        if not check_files:
            log('No any checkpoint in {}. Loading network skipped.'.format(
                save_path))
            return

        # NOTE(review): on Unix getctime is the inode change time, not the
        # creation time; it is used here as "most recently written".
        latest_file = max(check_files, key=os.path.getctime)

        # Deserialize onto CPU. load_state_dict copies tensors onto the devices
        # of the target parameters/optimizer state, so this also works when the
        # checkpoint was written from a GPU that is not present here.
        state_dict = torch.load(latest_file, map_location='cpu')

        if 'seed' in state_dict:
            self.seed = state_dict['seed']

        # Weights are loaded into the wrapped module when the model is
        # data-parallel wrapped.
        model = self.model
        if isinstance(model, (nn.DataParallel,
                              nn.parallel.DistributedDataParallel)):
            model = model.module
        model.load_state_dict(get_loadable_checkpoint(state_dict['model']))

        if load_optim:
            self.optimizer.load_state_dict(state_dict['optim'])

        if self.scheduler is not None:
            # The scheduler state is only written when a scheduler existed at
            # save time, so older/other checkpoints may not have it.
            if 'scheduler' in state_dict:
                self.scheduler.load_state_dict(state_dict['scheduler'])
            else:
                log('checkpoint \'{}\' has no scheduler state; '
                    'scheduler not restored.'.format(latest_file))

        self.step = state_dict['step']
        log('checkpoint \'{}\' is loaded. previous step={}'.format(
            latest_file, self.step))
def __load_model(model_name: str, pretrained_path: str) -> torch.nn.Module:
    """Build ``model_name`` on CUDA, load pretrained weights, return it in eval mode.

    Args:
        model_name: identifier passed to ``build_model``.
        pretrained_path: checkpoint file whose ``'model'`` entry holds the
            weights (normalized via ``get_loadable_checkpoint``).

    Returns:
        The CUDA-resident model with weights loaded, switched to eval mode.
    """
    print('Load model ...')
    model = build_model(model_name).cuda()
    # Deserialize onto CPU first: without map_location, torch.load restores
    # tensors to the device they were saved from, which fails when that GPU
    # index does not exist. load_state_dict then copies onto the model's GPU.
    chk = torch.load(pretrained_path, map_location='cpu')['model']
    model.load_state_dict(get_loadable_checkpoint(chk))
    model.eval()
    return model
Example #3
0
    def save(self, step: int):
        """Write a training checkpoint for ``step`` (and a best copy if needed).

        The checkpoint dict contains:

        * ``'step'``;
        * ``'model'`` (the model state dict passed through
          ``get_loadable_checkpoint``);
        * ``'optim'``;
        * ``'pretrained_step'`` (equal to ``step``);
        * ``'seed'``;
        * ``'scheduler'``, when a scheduler is configured.

        It is written to ``<model_dir>/<save_name>/step_<step:06d>.chkpt``.
        When ``self.best_valid_loss`` differs from ``self.cur_best_valid_loss``,
        the same dict is also written to ``<model_dir>/<save_name>.best.chkpt``
        and ``cur_best_valid_loss`` is synced to ``best_valid_loss``.

        Each file is first written to a dot-prefixed temporary file in the
        target directory, then moved into place with ``os.replace``. An
        interrupted save therefore never leaves a truncated file under the
        final checkpoint name.

        Args:
            step: current training step; stored in the checkpoint and used in
                the file name.
        """
        def _atomic_save(obj, path):
            # Write to a sibling temp file, then atomically rename it over the
            # target. The leading dot keeps the temp file out of plain '*'
            # glob matches, so a directory scan never picks it up.
            directory, filename = os.path.split(path)
            tmp_path = os.path.join(directory, '.{}.tmp'.format(filename))
            try:
                torch.save(obj, tmp_path)
                os.replace(tmp_path, path)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise

        checkpoint = {
            'step': step,
            'model': get_loadable_checkpoint(self.model.state_dict()),
            'optim': self.optimizer.state_dict(),
            'pretrained_step': step,
            'seed': self.seed,
        }
        if self.scheduler is not None:
            checkpoint['scheduler'] = self.scheduler.state_dict()

        # per-step checkpoint used for resuming training
        save_name = self.save_name
        save_path = os.path.join(self.model_dir, save_name)
        os.makedirs(save_path, exist_ok=True)
        _atomic_save(checkpoint,
                     os.path.join(save_path, 'step_{:06d}.chkpt'.format(step)))

        # best checkpoint, refreshed only when the tracked best loss moved
        if self.best_valid_loss != self.cur_best_valid_loss:
            _atomic_save(checkpoint,
                         os.path.join(self.model_dir, save_name + '.best.chkpt'))
            self.cur_best_valid_loss = self.best_valid_loss

        log('step %d / saved model.' % step)
Example #4
0
def __load_model(model_name: str, pretrained_path: str) -> torch.nn.Module:
    """Build ``model_name`` on CPU, load pretrained weights, return it in eval mode.

    Args:
        model_name: identifier passed to ``build_model``.
        pretrained_path: checkpoint file; its ``'model'`` entry holds the
            weights, which are normalized via ``get_loadable_checkpoint``.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.
    """
    print('Load model ...')
    network = build_model(model_name)

    # All tensors in the file are mapped onto the CPU while deserializing.
    cpu = torch.device('cpu')
    checkpoint = torch.load(pretrained_path, map_location=cpu)
    weights = get_loadable_checkpoint(checkpoint['model'])

    network.load_state_dict(weights)
    network.eval()
    return network
Example #5
0
 def load_pretrained_model(self):
     """Load weights from the checkpoint at ``self.pretrained_trained`` into ``self.model``.

     Only the checkpoint's ``'model'`` entry is used; it is normalized with
     ``get_loadable_checkpoint`` before ``load_state_dict``.

     Raises:
         AssertionError: if the path is unset/empty or does not exist. It is
             raised explicitly rather than via ``assert``, so the check still
             runs under ``python -O``. The type is kept as AssertionError for
             compatibility with existing callers.
     """
     path = self.pretrained_trained
     # Check for an unset path first: os.path.exists(None) raises TypeError,
     # which would hide the intended message.
     if not path or not os.path.exists(path):
         raise AssertionError('You must define pretrained path!')
     # Deserialize onto CPU; load_state_dict copies onto the model's device.
     checkpoint = torch.load(path, map_location='cpu')
     self.model.load_state_dict(get_loadable_checkpoint(checkpoint['model']))