Example #1
    def test_train_bert_model(self):
        preprocessor = SPMPreprocessor(
            (self.train_data[0] + self.valid_data[0],
             self.train_data[1] + self.valid_data[1]),
            self.train_labels + self.valid_labels,
            use_word=False,
            use_char=False,
            use_bert=True,
            use_bert_model=True,
            bert_vocab_file=self.bert_vocab_file,
            max_len=10)
        spm_model = BertSPM(num_class=self.num_class,
                            bert_config_file=self.bert_config_file,
                            bert_checkpoint_file=self.bert_model_file,
                            bert_trainable=True,
                            max_len=preprocessor.max_len).build_model()

        spm_trainer = SPMTrainer(spm_model, preprocessor)
        spm_trainer.train(self.train_data,
                          self.train_labels,
                          self.valid_data,
                          self.valid_labels,
                          batch_size=6,
                          epochs=2)
        assert not os.path.exists(self.json_file)
        assert not os.path.exists(self.weights_file)
Example #2
File: spm.py  Project: zouxiaoshi/fancy-nlp
    def load(self,
             preprocessor_file: str,
             json_file: str,
             weights_file: str,
             custom_objects: Optional[Dict[str, Any]] = None) -> None:
        """load spm application

        Args:
            preprocessor_file: path to load preprocessor
            json_file: path to load model architecture
            weights_file: path to load model weights
            custom_objects: Optional dictionary mapping names (strings) to custom classes or
                            functions to be considered during deserialization. Must be provided
                            when using custom layers.

        """
        self.preprocessor = SPMPreprocessor.load(preprocessor_file)
        logging.info('Load preprocessor from {}'.format(preprocessor_file))

        custom_objects = custom_objects or {}
        custom_objects.update(get_custom_objects())
        with open(json_file, 'r') as reader:
            self.model = tf.keras.models.model_from_json(
                reader.read(), custom_objects=custom_objects)
        logging.info('Load model architecture from {}'.format(json_file))

        self.model.load_weights(weights_file)
        logging.info('Load model weights from {}'.format(weights_file))

        self.trainer = SPMTrainer(self.model, self.preprocessor)
        self.predictor = SPMPredictor(self.model, self.preprocessor)
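
A minimal usage sketch of the load() API documented above; the import path and file names are assumptions for illustration, and custom_objects is only needed when the saved architecture contains custom layers.

# Hedged sketch: restore an SPM application previously saved with SPM.save().
# The import path and the three file paths below are hypothetical.
from fancy_nlp.applications import SPM

spm = SPM(use_pretrained=False)
spm.load(preprocessor_file='spm_preprocessor.pkl',
         json_file='spm.json',
         weights_file='spm.hdf5')
# load() rebuilds the trainer and predictor, so the application can predict right away:
print(spm.predict(('我是中国人', '我爱中国')))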
Example #3
    def test_spm_generator(self):
        test_file = os.path.join(os.path.dirname(__file__), '../../../data/spm/webank/example.txt')
        x_train, y_train = load_spm_data_and_labels(test_file)

        preprocessor = SPMPreprocessor(x_train, y_train)
        generator = SPMGenerator(preprocessor, x_train, batch_size=64)
        assert len(generator) == math.ceil(len(x_train[0]) / 64)
        for i, (features, y) in enumerate(generator):
            if i < len(generator) - 1:
                assert features[0].shape[0] == features[1].shape[0] == 64
                assert y is None
            else:
                assert features[0].shape[0] == features[1].shape[0] == \
                       len(x_train[0]) - 64 * (len(generator) - 1)
                assert y is None
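
The assertions above encode the generator's batching contract. A tiny self-contained sketch of the same arithmetic, using hypothetical sizes:

# Hypothetical numbers illustrating the batch-count arithmetic asserted above:
# 100 text pairs with batch_size=64 yield ceil(100 / 64) = 2 batches,
# and the last batch holds the remaining 100 - 64 * 1 = 36 pairs.
import math

n_pairs, batch_size = 100, 64
n_batches = math.ceil(n_pairs / batch_size)
last_batch_size = n_pairs - batch_size * (n_batches - 1)
assert (n_batches, last_batch_size) == (2, 36)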
Example #4
    def setup_class(self):
        x_train, y_train = load_spm_data_and_labels(self.test_file)
        self.preprocessor = SPMPreprocessor(
            x_train,
            y_train,
            use_word=True,
            use_char=True,
            use_bert=False,
            bert_vocab_file=self.bert_vocab_file,
            char_embed_type='word2vec',
            word_embed_type='word2vec',
            max_len=10)
        self.num_class = self.preprocessor.num_class
        self.char_embeddings = self.preprocessor.char_embeddings
        self.char_vocab_size = self.preprocessor.char_vocab_size
        self.char_embed_dim = self.preprocessor.char_embed_dim

        self.word_embeddings = self.preprocessor.word_embeddings
        self.word_vocab_size = self.preprocessor.word_vocab_size
        self.word_embed_dim = self.preprocessor.word_embed_dim
        self.checkpoint_dir = os.path.dirname(__file__)
Example #5
    def test_train_no_word(self):
        preprocessor = SPMPreprocessor(
            (self.train_data[0] + self.valid_data[0],
             self.train_data[1] + self.valid_data[1]),
            self.train_labels + self.valid_labels,
            use_word=False,
            use_char=True,
            use_bert=True,
            bert_vocab_file=self.bert_vocab_file,
            char_embed_type='word2vec',
            max_len=10)
        self.num_class = preprocessor.num_class
        self.char_embeddings = preprocessor.char_embeddings
        self.char_vocab_size = preprocessor.char_vocab_size
        self.char_embed_dim = preprocessor.char_embed_dim

        spm_model = SiameseCNN(num_class=self.num_class,
                               use_word=False,
                               use_char=True,
                               char_embeddings=self.char_embeddings,
                               char_vocab_size=self.char_vocab_size,
                               char_embed_dim=self.char_embed_dim,
                               char_embed_trainable=False,
                               use_bert=True,
                               bert_config_file=self.bert_config_file,
                               bert_checkpoint_file=self.bert_model_file,
                               max_len=preprocessor.max_len).build_model()

        spm_trainer = SPMTrainer(spm_model, preprocessor)
        spm_trainer.train(self.train_data,
                          self.train_labels,
                          self.valid_data,
                          self.valid_labels,
                          batch_size=6,
                          epochs=2)
        assert not os.path.exists(self.json_file)
        assert not os.path.exists(self.weights_file)
Example #6
    def setup_class(self):
        self.train_data, self.train_labels, self.valid_data, self.valid_labels = \
            load_spm_data_and_labels(self.test_file, split_mode=1)
        self.preprocessor = SPMPreprocessor(
            (self.train_data[0] + self.valid_data[0],
             self.train_data[1] + self.valid_data[1]),
            self.train_labels + self.valid_labels,
            use_word=True,
            use_char=True,
            bert_vocab_file=self.bert_vocab_file,
            word_embed_type='word2vec',
            char_embed_type='word2vec',
            max_len=10)
        self.num_class = self.preprocessor.num_class
        self.char_embeddings = self.preprocessor.char_embeddings
        self.char_vocab_size = self.preprocessor.char_vocab_size
        self.char_embed_dim = self.preprocessor.char_embed_dim

        self.word_embeddings = self.preprocessor.word_embeddings
        self.word_vocab_size = self.preprocessor.word_vocab_size
        self.word_embed_dim = self.preprocessor.word_embed_dim
        self.checkpoint_dir = os.path.dirname(__file__)

        self.spm_model = SiameseCNN(
            num_class=self.num_class,
            use_word=True,
            word_embeddings=self.word_embeddings,
            word_vocab_size=self.word_vocab_size,
            word_embed_dim=self.word_embed_dim,
            word_embed_trainable=False,
            use_char=True,
            char_embeddings=self.char_embeddings,
            char_vocab_size=self.char_vocab_size,
            char_embed_dim=self.char_embed_dim,
            char_embed_trainable=False,
            use_bert=False,
            bert_config_file=self.bert_config_file,
            bert_checkpoint_file=self.bert_model_file,
            bert_trainable=True,
            max_len=self.preprocessor.max_len,
            max_word_len=self.preprocessor.max_word_len).build_model()

        self.swa_model = SiameseCNN(
            num_class=self.num_class,
            use_word=True,
            word_embeddings=self.word_embeddings,
            word_vocab_size=self.word_vocab_size,
            word_embed_dim=self.word_embed_dim,
            word_embed_trainable=False,
            use_char=True,
            char_embeddings=self.char_embeddings,
            char_vocab_size=self.char_vocab_size,
            char_embed_dim=self.char_embed_dim,
            char_embed_trainable=False,
            use_bert=False,
            bert_config_file=self.bert_config_file,
            bert_checkpoint_file=self.bert_model_file,
            bert_trainable=True,
            max_len=self.preprocessor.max_len,
            max_word_len=self.preprocessor.max_word_len).build_model()

        self.spm_trainer = SPMTrainer(self.spm_model, self.preprocessor)

        self.json_file = os.path.join(self.checkpoint_dir,
                                      'siamese_cnn_spm.json')
        self.weights_file = os.path.join(self.checkpoint_dir,
                                         'siamese_cnn_spm.hdf5')
Example #7
File: spm.py  Project: zouxiaoshi/fancy-nlp
    def fit(self,
            train_data: Tuple[List[str], List[str]],
            train_labels: List[str],
            valid_data: Optional[Tuple[List[str], List[str]]] = None,
            valid_labels: Optional[List[str]] = None,
            spm_model_type: str = 'siamese_cnn',
            use_word: bool = True,
            external_word_dict: Optional[List[str]] = None,
            word_embed_type: Optional[str] = 'word2vec',
            word_embed_dim: int = 300,
            word_embed_trainable: bool = True,
            use_char: bool = False,
            char_embed_type: Optional[str] = 'word2vec',
            char_embed_dim: int = 300,
            char_embed_trainable: bool = True,
            use_bert: bool = False,
            bert_vocab_file: Optional[str] = None,
            bert_config_file: Optional[str] = None,
            bert_checkpoint_file: Optional[str] = None,
            bert_trainable: bool = False,
            max_len: Optional[int] = None,
            max_word_len: Optional[int] = None,
            optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam',
            batch_size: int = 32,
            epochs: int = 50,
            callback_list: Optional[List[str]] = None,
            checkpoint_dir: Optional[str] = None,
            model_name: Optional[str] = None,
            load_swa_model: bool = False,
            **kwargs) -> None:
        """Train spm model using provided data

        Args:
            train_data: list of untokenized text pairs for training,
                        like ``[['我是中国人', ...], ['我爱中国', ...]]``
            train_labels: label strings of train_data
            valid_data: list of untokenized text pairs for evaluation
            valid_labels: label strings of valid_data
            spm_model_type: str, which spm model to use
            use_word: boolean, whether to use word embedding as input
            external_word_dict: external word dictionary
            word_embed_type: str, can be a pre-trained embedding filename or pre-trained embedding
                             methods (word2vec, glove, fasttext)
            word_embed_dim: int, dimensionality of word embedding
            word_embed_trainable: boolean, whether to update word embedding during training
            use_char: boolean, whether to use char as input
            char_embed_type: str, similar to 'word_embed_type'
            char_embed_dim: int, similar to 'word_embed_dim'
            char_embed_trainable: boolean, similar to 'word_embed_trainable'
            use_bert: boolean, whether to use bert embedding as input
            bert_vocab_file: str, path to bert's vocabulary file
            bert_config_file: str, path to bert's configuration file
            bert_checkpoint_file: str, path to bert's checkpoint file
            bert_trainable: boolean, whether to update bert during training
            use_bert_model: boolean, whether to use the traditional bert model that combines the two
                            sentences into one input; derived internally from spm_model_type (True
                            only when spm_model_type is 'bert') rather than passed as an argument
            max_len: int, max sequence length. If None, we dynamically use the max length of one batch
                     as max_len. However, max_len must be provided when using bert as input.
            max_word_len: int, max word length. If None, we dynamically use the max word length of one
                          batch as max_word_len.
            optimizer: str or instance of `keras.optimizers.Optimizer`, indicating the optimizer to
                       use during training
            batch_size: num of samples per gradient update
            epochs: num of epochs to train the model
            callback_list: list of str, each item indicates a callback to apply during training.
                           Currently, we support 'modelcheckpoint' for the `ModelCheckpoint`
                           callback, 'earlystopping' for the `EarlyStopping` callback and 'swa'
                           for the `SWA` callback. We will automatically add the `SPMMetric`
                           callback when valid_data and valid_labels are both provided.
            checkpoint_dir: str, directory to save spm model, must be provided when using
                            `ModelCheckpoint` or `SWA` callback.
            model_name: str, prefix of spm model's weights file, must be provided when using
                        `ModelCheckpoint` or `SWA` callback.
                        For example, if checkpoint_dir is 'ckpt' and model_name is 'model', the
                        weights of spm model saved by `ModelCheckpoint` callback will be
                        'ckpt/model.hdf5' and by `SWA` callback will be 'ckpt/model_swa.hdf5'
            load_swa_model: boolean, whether to load the swa model; only applies when using the
                            `SWA` callback.
            **kwargs: other arguments for building the spm model, such as "rnn_units", "fc_dim" etc.
        """
        use_bert_model = (spm_model_type == 'bert')

        # data preprocessing
        self.preprocessor = SPMPreprocessor(
            train_data=train_data,
            train_labels=train_labels,
            use_word=use_word,
            use_char=use_char,
            use_bert=use_bert,
            use_bert_model=use_bert_model,
            external_word_dict=external_word_dict,
            bert_vocab_file=bert_vocab_file,
            char_embed_type=char_embed_type,
            char_embed_dim=char_embed_dim,
            word_embed_type=word_embed_type,
            word_embed_dim=word_embed_dim,
            max_len=max_len,
            max_word_len=max_word_len)

        # build model
        self.model = self.get_spm_model(
            spm_model_type=spm_model_type,
            num_class=self.preprocessor.num_class,
            use_word=use_word,
            word_embeddings=self.preprocessor.word_embeddings,
            word_vocab_size=self.preprocessor.word_vocab_size,
            word_embed_dim=self.preprocessor.word_embed_dim,
            word_embed_trainable=word_embed_trainable,
            use_char=use_char,
            char_embeddings=self.preprocessor.char_embeddings,
            char_vocab_size=self.preprocessor.char_vocab_size,
            char_embed_dim=self.preprocessor.char_embed_dim,
            char_embed_trainable=char_embed_trainable,
            use_bert=use_bert,
            bert_config_file=bert_config_file,
            bert_checkpoint_file=bert_checkpoint_file,
            bert_trainable=bert_trainable,
            max_len=self.preprocessor.max_len,
            max_word_len=self.preprocessor.max_word_len,
            optimizer=optimizer,
            **kwargs)

        # build swa model
        if callback_list and 'swa' in callback_list:
            swa_model = self.get_spm_model(
                spm_model_type=spm_model_type,
                num_class=self.preprocessor.num_class,
                use_word=use_word,
                word_embeddings=self.preprocessor.word_embeddings,
                word_vocab_size=self.preprocessor.word_vocab_size,
                word_embed_dim=self.preprocessor.word_embed_dim,
                word_embed_trainable=word_embed_trainable,
                use_char=use_char,
                char_embeddings=self.preprocessor.char_embeddings,
                char_vocab_size=self.preprocessor.char_vocab_size,
                char_embed_dim=self.preprocessor.char_embed_dim,
                char_embed_trainable=char_embed_trainable,
                use_bert=use_bert,
                bert_config_file=bert_config_file,
                bert_checkpoint_file=bert_checkpoint_file,
                bert_trainable=bert_trainable,
                max_len=self.preprocessor.max_len,
                max_word_len=self.preprocessor.max_word_len,
                optimizer=optimizer,
                **kwargs)
        else:
            swa_model = None

        # train model
        self.trainer = SPMTrainer(self.model, self.preprocessor)
        self.trainer.train_generator(train_data, train_labels, valid_data,
                                     valid_labels, batch_size, epochs,
                                     callback_list, checkpoint_dir, model_name,
                                     swa_model, load_swa_model)

        # predict model
        self.predictor = SPMPredictor(self.model, self.preprocessor)

        if valid_data is not None and valid_labels is not None:
            logging.info('Evaluating on validation data...')
            self.score(valid_data, valid_labels)
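
A hedged sketch of calling fit() with the callbacks described in the docstring above. The data file, directory and model name are hypothetical; per the docstring, checkpoint_dir and model_name are required once 'modelcheckpoint' or 'swa' appears in callback_list.

# Hypothetical end-to-end training call; import paths, file paths and
# hyperparameters are examples only.
from fancy_nlp.applications import SPM
from fancy_nlp.utils import load_spm_data_and_labels

train_data, train_labels, valid_data, valid_labels = \
    load_spm_data_and_labels('data/spm/webank/example.txt', split_mode=1)

spm = SPM(use_pretrained=False)
spm.fit(train_data, train_labels,
        valid_data=valid_data,
        valid_labels=valid_labels,
        spm_model_type='siamese_cnn',
        use_word=True,
        use_char=True,
        max_len=10,
        batch_size=32,
        epochs=10,
        callback_list=['modelcheckpoint', 'earlystopping', 'swa'],
        checkpoint_dir='ckpt',           # ModelCheckpoint writes ckpt/siamese_cnn_spm.hdf5
        model_name='siamese_cnn_spm',    # SWA writes ckpt/siamese_cnn_spm_swa.hdf5 (per docstring)
        load_swa_model=True)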
Example #8
File: spm.py  Project: zouxiaoshi/fancy-nlp
class SPM(object):
    """SPM application"""
    def __init__(self, use_pretrained: bool = True) -> None:
        self.preprocessor = None
        self.model = None
        self.trainer = None
        self.predictor = None

        if use_pretrained:
            self.load_pretrained_model()

    def fit(self,
            train_data: Tuple[List[str], List[str]],
            train_labels: List[str],
            valid_data: Optional[Tuple[List[str], List[str]]] = None,
            valid_labels: Optional[List[str]] = None,
            spm_model_type: str = 'siamese_cnn',
            use_word: bool = True,
            external_word_dict: Optional[List[str]] = None,
            word_embed_type: Optional[str] = 'word2vec',
            word_embed_dim: int = 300,
            word_embed_trainable: bool = True,
            use_char: bool = False,
            char_embed_type: Optional[str] = 'word2vec',
            char_embed_dim: int = 300,
            char_embed_trainable: bool = True,
            use_bert: bool = False,
            bert_vocab_file: Optional[str] = None,
            bert_config_file: Optional[str] = None,
            bert_checkpoint_file: Optional[str] = None,
            bert_trainable: bool = False,
            max_len: Optional[int] = None,
            max_word_len: Optional[int] = None,
            optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam',
            batch_size: int = 32,
            epochs: int = 50,
            callback_list: Optional[List[str]] = None,
            checkpoint_dir: Optional[str] = None,
            model_name: Optional[str] = None,
            load_swa_model: bool = False,
            **kwargs) -> None:
        """Train spm model using provided data

        Args:
            train_data: list of untokenized text pairs for training,
                        like ``[['我是中国人', ...], ['我爱中国', ...]]``
            train_labels: label strings of train_data
            valid_data: list of untokenized text pairs for evaluation
            valid_labels: label strings of valid_data
            spm_model_type: str, which spm model to use
            use_word: boolean, whether to use word embedding as input
            external_word_dict: external word dictionary
            word_embed_type: str, can be a pre-trained embedding filename or pre-trained embedding
                             methods (word2vec, glove, fasttext)
            word_embed_dim: int, dimensionality of word embedding
            word_embed_trainable: boolean, whether to update word embedding during training
            use_char: boolean, whether to use char as input
            char_embed_type: str, similar to 'word_embed_type'
            char_embed_dim: int, similar to 'word_embed_dim'
            char_embed_trainable: boolean, similar to 'word_embed_trainable'
            use_bert: boolean, whether to use bert embedding as input
            bert_vocab_file: str, path to bert's vocabulary file
            bert_config_file: str, path to bert's configuration file
            bert_checkpoint_file: str, path to bert's checkpoint file
            bert_trainable: boolean, whether to update bert during training
            use_bert_model: boolean, whether to use the traditional bert model that combines the two
                            sentences into one input; derived internally from spm_model_type (True
                            only when spm_model_type is 'bert') rather than passed as an argument
            max_len: int, max sequence length. If None, we dynamically use the max length of one batch
                     as max_len. However, max_len must be provided when using bert as input.
            max_word_len: int, max word length. If None, we dynamically use the max word length of one
                          batch as max_word_len.
            optimizer: str or instance of `keras.optimizers.Optimizer`, indicating the optimizer to
                       use during training
            batch_size: num of samples per gradient update
            epochs: num of epochs to train the model
            callback_list: list of str, each item indicates a callback to apply during training.
                           Currently, we support 'modelcheckpoint' for the `ModelCheckpoint`
                           callback, 'earlystopping' for the `EarlyStopping` callback and 'swa'
                           for the `SWA` callback. We will automatically add the `SPMMetric`
                           callback when valid_data and valid_labels are both provided.
            checkpoint_dir: str, directory to save spm model, must be provided when using
                            `ModelCheckpoint` or `SWA` callback.
            model_name: str, prefix of spm model's weights file, must be provided when using
                        `ModelCheckpoint` or `SWA` callback.
                        For example, if checkpoint_dir is 'ckpt' and model_name is 'model', the
                        weights of spm model saved by `ModelCheckpoint` callback will be
                        'ckpt/model.hdf5' and by `SWA` callback will be 'ckpt/model_swa.hdf5'
            load_swa_model: boolean, whether to load the swa model; only applies when using the
                            `SWA` callback.
            **kwargs: other arguments for building the spm model, such as "rnn_units", "fc_dim" etc.
        """
        use_bert_model = (spm_model_type == 'bert')

        # data preprocessing
        self.preprocessor = SPMPreprocessor(
            train_data=train_data,
            train_labels=train_labels,
            use_word=use_word,
            use_char=use_char,
            use_bert=use_bert,
            use_bert_model=use_bert_model,
            external_word_dict=external_word_dict,
            bert_vocab_file=bert_vocab_file,
            char_embed_type=char_embed_type,
            char_embed_dim=char_embed_dim,
            word_embed_type=word_embed_type,
            word_embed_dim=word_embed_dim,
            max_len=max_len,
            max_word_len=max_word_len)

        # build model
        self.model = self.get_spm_model(
            spm_model_type=spm_model_type,
            num_class=self.preprocessor.num_class,
            use_word=use_word,
            word_embeddings=self.preprocessor.word_embeddings,
            word_vocab_size=self.preprocessor.word_vocab_size,
            word_embed_dim=self.preprocessor.word_embed_dim,
            word_embed_trainable=word_embed_trainable,
            use_char=use_char,
            char_embeddings=self.preprocessor.char_embeddings,
            char_vocab_size=self.preprocessor.char_vocab_size,
            char_embed_dim=self.preprocessor.char_embed_dim,
            char_embed_trainable=char_embed_trainable,
            use_bert=use_bert,
            bert_config_file=bert_config_file,
            bert_checkpoint_file=bert_checkpoint_file,
            bert_trainable=bert_trainable,
            max_len=self.preprocessor.max_len,
            max_word_len=self.preprocessor.max_word_len,
            optimizer=optimizer,
            **kwargs)

        # build swa model
        if callback_list and 'swa' in callback_list:
            swa_model = self.get_spm_model(
                spm_model_type=spm_model_type,
                num_class=self.preprocessor.num_class,
                use_word=use_word,
                word_embeddings=self.preprocessor.word_embeddings,
                word_vocab_size=self.preprocessor.word_vocab_size,
                word_embed_dim=self.preprocessor.word_embed_dim,
                word_embed_trainable=word_embed_trainable,
                use_char=use_char,
                char_embeddings=self.preprocessor.char_embeddings,
                char_vocab_size=self.preprocessor.char_vocab_size,
                char_embed_dim=self.preprocessor.char_embed_dim,
                char_embed_trainable=char_embed_trainable,
                use_bert=use_bert,
                bert_config_file=bert_config_file,
                bert_checkpoint_file=bert_checkpoint_file,
                bert_trainable=bert_trainable,
                max_len=self.preprocessor.max_len,
                max_word_len=self.preprocessor.max_word_len,
                optimizer=optimizer,
                **kwargs)
        else:
            swa_model = None

        # train model
        self.trainer = SPMTrainer(self.model, self.preprocessor)
        self.trainer.train_generator(train_data, train_labels, valid_data,
                                     valid_labels, batch_size, epochs,
                                     callback_list, checkpoint_dir, model_name,
                                     swa_model, load_swa_model)

        # predict model
        self.predictor = SPMPredictor(self.model, self.preprocessor)

        if valid_data is not None and valid_labels is not None:
            logging.info('Evaluating on validation data...')
            self.score(valid_data, valid_labels)

    def score(self, valid_data: Tuple[List[str], List[str]],
              valid_labels: List[str]) -> float:
        """Return the f1 score of the model over validation data

        Args:
            valid_data: list of untokenized text pairs
            valid_labels: list of label strings

        Returns:
            the f1 score of the model over the validation data

        """
        if self.trainer:
            return self.trainer.evaluate(valid_data, valid_labels)
        else:
            logging.fatal(
                'Trainer is None! Call fit() or load() to get trainer.')

    def predict(self, test_text: Tuple[str, str]) -> str:
        """Return prediction of the model for test data

        Args:
            test_text: a pair of untokenized text

        Returns:
            the predicted label string for the text pair

        """
        if self.predictor:
            return self.predictor.matching(test_text)
        else:
            logging.fatal(
                'Predictor is None! Call fit() or load() to get predictor.')

    def predict_batch(self, test_texts: Tuple[List[str],
                                              List[str]]) -> List[str]:
        """Return predictions of the model for test data

        Args:
            test_texts: list of untokenized text pairs

        Returns:
            a list of predicted label strings

        """
        if self.predictor:
            return self.predictor.matching_batch(test_texts)
        else:
            logging.fatal(
                'Predictor is None! Call fit() or load() to get predictor.')

    def analyze(self, text: Tuple[str, str]) -> Tuple[str, np.ndarray]:
        """Analyze text and return matching result with probability.

        Args:
            text: a pair of untokenized text

        Returns:
            a tuple of the predicted label string and its probability array

        """
        if self.predictor:
            return self.predictor.matching_with_prob(text)
        else:
            logging.fatal(
                'Predictor is None! Call fit() or load() to get predictor.')

    def analyze_batch(
            self, texts: Tuple[List[str],
                               List[str]]) -> List[Tuple[str, np.ndarray]]:
        """Analyze text and return matching result with probability.

        Args:
            texts: list of untokenized text pairs

        Returns:
            a list of (predicted label string, probability array) tuples

        """
        if self.predictor:
            return self.predictor.matching_with_prob_batch(texts)
        else:
            logging.fatal(
                'Predictor is None! Call fit() or load() to get predictor.')

    def save(self,
             preprocessor_file: str,
             json_file: str,
             weights_file: Optional[str] = None) -> None:
        """save spm application

        Args:
            preprocessor_file: path to save preprocessor
            json_file: path to save model architecture
            weights_file: path to save model weights, can be None. When we use `ModelCheckpoint`
                          or `SWA` callback, model's weights will be saved to disk after training.
                          In that case, we don't need to save it again. We usually set weights_file
                          to be None.
        """
        self.preprocessor.save(preprocessor_file)
        logging.info('Save preprocessor to {}'.format(preprocessor_file))

        model_json = self.model.to_json()
        with open(json_file, 'w') as writer:
            writer.write(model_json)
        logging.info('Save model architecture to {}'.format(json_file))

        if weights_file:
            self.model.save_weights(weights_file)
            logging.info('Save model weights to {}'.format(weights_file))

    def load(self,
             preprocessor_file: str,
             json_file: str,
             weights_file: str,
             custom_objects: Optional[Dict[str, Any]] = None) -> None:
        """load spm application

        Args:
            preprocessor_file: path to load preprocessor
            json_file: path to load model architecture
            weights_file: path to load model weights
            custom_objects: Optional dictionary mapping names (strings) to custom classes or
                            functions to be considered during deserialization. Must be provided
                            when using custom layers.

        """
        self.preprocessor = SPMPreprocessor.load(preprocessor_file)
        logging.info('Load preprocessor from {}'.format(preprocessor_file))

        custom_objects = custom_objects or {}
        custom_objects.update(get_custom_objects())
        with open(json_file, 'r') as reader:
            self.model = tf.keras.models.model_from_json(
                reader.read(), custom_objects=custom_objects)
        logging.info('Load model architecture from {}'.format(json_file))

        self.model.load_weights(weights_file)
        logging.info('Load model weights from {}'.format(weights_file))

        self.trainer = SPMTrainer(self.model, self.preprocessor)
        self.predictor = SPMPredictor(self.model, self.preprocessor)

    @staticmethod
    def get_spm_model(spm_model_type: str, num_class: int, use_word: bool,
                      word_embeddings: Optional[np.ndarray],
                      word_vocab_size: int, word_embed_dim: int,
                      word_embed_trainable: bool, use_char: bool,
                      char_embeddings: Optional[np.ndarray],
                      char_vocab_size: int, char_embed_dim: int,
                      char_embed_trainable: bool, use_bert: bool,
                      bert_config_file: Optional[str],
                      bert_checkpoint_file: Optional[str],
                      bert_trainable: bool, max_len: Optional[int],
                      max_word_len: Optional[int],
                      optimizer: Union[str, tf.keras.optimizers.Optimizer],
                      **kwargs) -> tf.keras.models.Model:
        """build spm models by model_type

        Args:
            spm_model_type: str, which spm model to use
            num_class: int, the number of classification classes
            use_word: boolean, whether to use word embedding as input
            word_embeddings: np.ndarray, word embeddings
            word_vocab_size: int, the number of words in vocabulary
            word_embed_dim: int, dimensionality of word embedding
            word_embed_trainable: boolean, whether to update word embedding during training
            use_char: boolean, whether to use char as input
            char_embeddings: np.ndarray, char embeddings
            char_vocab_size: int, the number of chars in vocabulary
            char_embed_dim: int, dimensionality of char embedding
            char_embed_trainable: boolean, similar to 'word_embed_trainable'
            use_bert: boolean, whether to use bert embedding as input
            bert_config_file: str, path to bert's configuration file
            bert_checkpoint_file: str, path to bert's checkpoint file
            bert_trainable: boolean, whether to update bert during training
            max_len: int, max sequence length. If None, we dynamically use the max length of one batch
                     as max_len. However, max_len must be provided when using bert as input.
            max_word_len: int, max word length. If None, we dynamically use the max word length of one
                          batch as max_word_len.
            optimizer: str or instance of `keras.optimizers.Optimizer`, indicating the optimizer to
                       use during training
            **kwargs: other arguments for building the spm model, such as "rnn_units", "fc_dim" etc.

        Returns:
            the tf.keras model built by the selected spm model class
        """
        spm_model_all = {
            'siamese_cnn': SiameseCNN,
            'siamese_bilstm': SiameseBiLSTM,
            'siamese_bigru': SiameseBiGRU,
            'esim': ESIM,
            'bimpm': BiMPM,
            'bert': BertSPM
        }
        if spm_model_type in spm_model_all:
            spm_model = spm_model_all[spm_model_type](
                num_class=num_class,
                use_word=use_word,
                word_embeddings=word_embeddings,
                word_vocab_size=word_vocab_size,
                word_embed_dim=word_embed_dim,
                word_embed_trainable=word_embed_trainable,
                use_char=use_char,
                char_embeddings=char_embeddings,
                char_vocab_size=char_vocab_size,
                char_embed_dim=char_embed_dim,
                char_embed_trainable=char_embed_trainable,
                use_bert=use_bert,
                bert_config_file=bert_config_file,
                bert_checkpoint_file=bert_checkpoint_file,
                bert_trainable=bert_trainable,
                max_len=max_len,
                max_word_len=max_word_len,
                optimizer=optimizer,
                **kwargs)

        else:
            raise ValueError(
                '`spm_model_type` not understood: {}'.format(spm_model_type))

        return spm_model.build_model()

    def load_pretrained_model(self) -> None:
        cache_subdir = 'pretrained_models'

        preprocessor_file = tf.keras.utils.get_file(
            fname='webank_spm_siamese_cnn_word_preprocessor.pkl',
            origin=MODEL_STORAGE_PREFIX +
            'webank_spm_siamese_cnn_word_preprocessor.pkl',
            cache_subdir=cache_subdir,
            cache_dir=CACHE_DIR)
        json_file = tf.keras.utils.get_file(
            fname='webank_spm_siamese_cnn_word.json',
            origin=MODEL_STORAGE_PREFIX + 'webank_spm_siamese_cnn_word.json',
            cache_subdir=cache_subdir,
            cache_dir=CACHE_DIR)
        weights_file = tf.keras.utils.get_file(
            fname='webank_spm_siamese_cnn_word.hdf5',
            origin=MODEL_STORAGE_PREFIX + 'webank_spm_siamese_cnn_word.hdf5',
            cache_subdir=cache_subdir,
            cache_dir=CACHE_DIR)

        self.load(preprocessor_file, json_file, weights_file)
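
A hedged usage sketch of the SPM application defined above. With use_pretrained=True the constructor downloads and loads the WeBank siamese CNN model via load_pretrained_model(); the return shapes follow the type hints, and the save() paths are hypothetical.

# Hypothetical usage of the SPM class shown above; example sentences follow the docstring.
spm = SPM(use_pretrained=True)                        # downloads and loads the pretrained WeBank model

label = spm.predict(('我是中国人', '我爱中国'))              # predicted label string
label, prob = spm.analyze(('我是中国人', '我爱中国'))        # (label, probability ndarray)
labels = spm.predict_batch((['我是中国人'], ['我爱中国']))    # list of label strings

# A fitted or loaded application can be persisted for later reuse (paths are hypothetical):
spm.save('spm_preprocessor.pkl', 'spm.json', 'spm.hdf5')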