def load_ckpt(self, model_path):
    """Load a training checkpoint for further training or predicting.

    Parameters are restored into the training init program when it has been
    built; otherwise they are restored into the prediction init program.

    Args:
        model_path: the path of saved checkpoint/parameters.

    Raises:
        AssertionError: if neither the training graph nor the prediction
            graph has been built yet.
    """
    # Prefer the training graph; fall back to the prediction graph.
    init_prog = self._train_init_prog
    if init_prog is None:
        init_prog = self._pred_init_prog
    # Explicit raise instead of `assert` so the check survives `python -O`.
    # AssertionError keeps the exception type the old `assert` produced, and
    # it is still a subclass of Exception (what the old fallback raised).
    if init_prog is None:
        raise AssertionError(
            "model graph not built. You should at least build_forward or "
            "build_predict_forward to load its checkpoint.")
    saver.init_pretraining_params(
        self._exe, model_path, convert=False,
        main_program=init_prog, strict=True)
def load_pretrain(self, pretrain_model_path=None):
    """Load pretrained parameters (or a checkpoint) via the startup program.

    Args:
        pretrain_model_path: path of the pretrained parameters. When None,
            ``self.main_conf['pretrain_model_path']`` is used instead.

    Raises:
        AssertionError: if no path is given and ``self.main_conf`` has no
            ``'pretrain_model_path'`` entry.
    """
    if pretrain_model_path is None:
        # Explicit raise rather than `assert`: under `python -O` the assert
        # would be stripped and a bare KeyError would surface instead.
        if 'pretrain_model_path' not in self.main_conf:
            raise AssertionError("pretrain_model_path NOT set.")
        pretrain_model_path = self.main_conf['pretrain_model_path']
    init_pretraining_params(self.exe, pretrain_model_path,
                            main_program=fluid.default_startup_program())
def load_pretrain(self, model_path, convert=False):
    """Load pretrained (backbone) parameters for training.

    Args:
        model_path: the path of saved pretrained parameters.
        convert: forwarded unchanged to ``saver.init_pretraining_params``.

    Raises:
        AssertionError: if the training graph has not been built yet.
    """
    # Explicit check instead of `assert`: under `python -O` the assert is
    # stripped and main_program=None would silently reach the saver.
    if self._train_init_prog is None:
        raise AssertionError(
            "training graph not found. You should at least build_forward "
            "to load its pretrained parameters.")
    saver.init_pretraining_params(self._exe, model_path, convert=convert,
                                  main_program=self._train_init_prog)
def load_ckpt(self, model_path):
    """Load a training checkpoint for further training or predicting.

    Parameters are restored into the training init program when it has been
    built; otherwise they are restored into the prediction init program.

    Args:
        model_path: the path of saved checkpoint/parameters.

    Raises:
        AssertionError: if neither the training graph nor the prediction
            graph has been built yet.
    """
    # Prefer the training graph; fall back to the prediction graph.
    init_prog = self._train_init_prog
    if init_prog is None:
        init_prog = self._pred_init_prog
    # Explicit raise instead of `assert` so the check survives `python -O`.
    # AssertionError keeps the exception type the old `assert` produced, and
    # it is still a subclass of Exception (what the old fallback raised).
    if init_prog is None:
        raise AssertionError(
            "model graph not built. You should at least build_forward or "
            "build_predict_forward to load its checkpoint.")
    saver.init_pretraining_params(
        self._exe, model_path, convert=False,
        main_program=init_prog, strict=True)