示例#1
0
文件: trainer.py 项目: yaweisun/PALM
    def load_ckpt(self, model_path):
        """
        load training checkpoint for further training or predicting.

        The checkpoint is loaded into the training init program when it has
        been built; otherwise it is loaded into the prediction init program.
        Loading goes through ``saver.init_pretraining_params`` with
        ``convert=False`` and ``strict=True``.

        Args:
            model_path: the path of saved checkpoint/parameters.

        Raises:
            AssertionError: if neither the training graph nor the prediction
                graph has been built.
        """
        # Prefer the training graph; fall back to the prediction graph.
        program = self._train_init_prog
        if program is None:
            program = self._pred_init_prog

        # Explicit raise instead of `assert` so the guard still runs under
        # `python -O`; AssertionError is kept so existing callers that catch
        # it keep working.
        if program is None:
            raise AssertionError(
                "model graph not built. You should at least build_forward or "
                "build_predict_forward to load its checkpoint.")

        saver.init_pretraining_params(
            self._exe,
            model_path,
            convert=False,
            main_program=program,
            strict=True)
示例#2
0
    def load_pretrain(self, pretrain_model_path=None):
        """Load pretrained parameters (or a checkpoint) via init_pretraining_params.

        The parameters are loaded with ``fluid.default_startup_program()`` as
        the target program.

        Args:
            pretrain_model_path: path of the saved parameters. When omitted,
                ``self.main_conf['pretrain_model_path']`` is used instead.
        """
        path = pretrain_model_path
        if path is None:
            # Fall back to the path configured in the main config.
            assert 'pretrain_model_path' in self.main_conf, "pretrain_model_path NOT set."
            path = self.main_conf['pretrain_model_path']

        init_pretraining_params(
            self.exe,
            path,
            main_program=fluid.default_startup_program(),
        )
示例#3
0
文件: trainer.py 项目: wuhuaha/padd
    def load_pretrain(self, model_path, convert=False):
        """
        Load pretrained (backbone) parameters into the training graph.

        Args:
            model_path: the path of saved pretrained parameters.
            convert: passed through unchanged to
                ``saver.init_pretraining_params``.
        """
        train_prog = self._train_init_prog
        assert train_prog is not None, "training graph not found. You should at least build_forward to load its pretrained parameters."

        saver.init_pretraining_params(
            self._exe,
            model_path,
            convert=convert,
            main_program=train_prog,
        )
    def load_ckpt(self, model_path):
        """
        load training checkpoint for further training or predicting.

        The checkpoint is loaded into the training init program when it has
        been built; otherwise it is loaded into the prediction init program.
        Loading goes through ``saver.init_pretraining_params`` with
        ``convert=False`` and ``strict=True``.

        Args:
            model_path: the path of saved checkpoint/parameters.

        Raises:
            AssertionError: if neither the training graph nor the prediction
                graph has been built.
        """
        # Prefer the training graph; fall back to the prediction graph.
        program = self._train_init_prog
        if program is None:
            program = self._pred_init_prog

        # Explicit raise instead of `assert` so the guard still runs under
        # `python -O`; AssertionError is kept so existing callers that catch
        # it keep working.
        if program is None:
            raise AssertionError(
                "model graph not built. You should at least build_forward or "
                "build_predict_forward to load its checkpoint.")

        saver.init_pretraining_params(self._exe,
                                      model_path,
                                      convert=False,
                                      main_program=program,
                                      strict=True)