def test_siamese_bilstm_model(self):
    """Build SiameseBiLSTM with two input configurations and round-trip the last one.

    The first configuration (word + char inputs) is only checked for building
    without raising. The second (char + BERT inputs) is additionally saved
    with ``save_keras_model``, checked for the presence of both output files,
    and reloaded with ``load_keras_model``.

    The saved files are removed in a ``finally`` clause so that a failure in
    any step does not leave artifacts behind in ``self.checkpoint_dir``.
    """
    # word, char -- build only; the resulting model is not used further
    SiameseBiLSTM(
        num_class=self.num_class,
        use_word=True,
        word_embeddings=self.word_embeddings,
        word_vocab_size=self.word_vocab_size,
        word_embed_dim=self.word_embed_dim,
        word_embed_trainable=False,
        use_char=True,
        char_embeddings=self.char_embeddings,
        char_vocab_size=self.char_vocab_size,
        char_embed_dim=self.char_embed_dim,
        char_embed_trainable=False,
        use_bert=False,
        max_len=10,
    ).build_model()

    # char, bert
    spm_model = SiameseBiLSTM(
        num_class=self.num_class,
        use_word=False,
        use_char=True,
        char_embeddings=self.char_embeddings,
        char_vocab_size=self.char_vocab_size,
        char_embed_dim=self.char_embed_dim,
        char_embed_trainable=False,
        use_bert=True,
        bert_config_file=self.bert_config_file,
        bert_checkpoint_file=self.bert_model_file,
        max_len=10,
    ).build_model()

    # test save and load
    json_file = os.path.join(self.checkpoint_dir, 'siamese_bilstm_spm.json')
    weights_file = os.path.join(self.checkpoint_dir, 'siamese_bilstm_spm.hdf5')
    try:
        save_keras_model(spm_model, json_file, weights_file)
        assert os.path.exists(json_file)
        assert os.path.exists(weights_file)

        load_keras_model(json_file, weights_file, custom_objects=get_custom_objects())
    finally:
        # Remove whatever was written, even if a step above failed. The
        # existence guard keeps a missing file from masking the real error.
        for path in (json_file, weights_file):
            if os.path.exists(path):
                os.remove(path)

    assert not os.path.exists(json_file)
    assert not os.path.exists(weights_file)
def test_bert_model(self):
    """Build a BertSPM model and round-trip it through save and load.

    The model is built with ``bert_trainable=True`` and ``max_len=10``, saved
    with ``save_keras_model``, checked for the presence of both output files,
    and reloaded with ``load_keras_model``.

    The saved files are removed in a ``finally`` clause so that a failure in
    any step does not leave artifacts behind in ``self.checkpoint_dir``.
    """
    spm_model = BertSPM(
        num_class=self.num_class,
        bert_config_file=self.bert_config_file,
        bert_checkpoint_file=self.bert_model_file,
        bert_trainable=True,
        max_len=10,
    ).build_model()

    # test save and load
    json_file = os.path.join(self.checkpoint_dir, 'bert_spm.json')
    weights_file = os.path.join(self.checkpoint_dir, 'bert_spm.hdf5')
    try:
        save_keras_model(spm_model, json_file, weights_file)
        assert os.path.exists(json_file)
        assert os.path.exists(weights_file)

        load_keras_model(json_file, weights_file, custom_objects=get_custom_objects())
    finally:
        # Remove whatever was written, even if a step above failed. The
        # existence guard keeps a missing file from masking the real error.
        for path in (json_file, weights_file):
            if os.path.exists(path):
                os.remove(path)

    assert not os.path.exists(json_file)
    assert not os.path.exists(weights_file)
def test_load(self):
    """Load the fixture model with the custom ``MyLayer`` and check its layer count.

    The loaded model is stored on ``self.model`` and is expected to contain
    exactly three layers.
    """
    architecture_path = self.test_json_file
    weights_path = self.test_weights_file
    # The custom layer class is passed by name to load_keras_model.
    layer_registry = {'MyLayer': MyLayer}

    self.model = load_keras_model(
        architecture_path,
        weights_path,
        custom_objects=layer_registry,
    )

    layer_count = len(self.model.layers)
    assert layer_count == 3
def test_bilstm_cnn_model(self):
    """Build BiLSTMCNNNER with four input configurations and round-trip the last one.

    The first three configurations are only checked for building without
    raising. The last one (char + word + BERT inputs with CRF) is additionally
    saved with ``save_keras_model``, checked for the presence of both output
    files, and reloaded with ``load_keras_model``.

    The saved files are removed in a ``finally`` clause so that a failure in
    any step does not leave artifacts behind in ``self.checkpoint_dir``.
    """
    # char, no CRF, no word input -- build only
    BiLSTMCNNNER(
        num_class=self.num_class,
        char_embeddings=self.char_embeddings,
        char_vocab_size=self.char_vocab_size,
        char_embed_dim=self.char_embed_dim,
        char_embed_trainable=False,
        use_word=False,
        use_crf=False,
    ).build_model()

    # char, CRF, no word, no bert input -- build only
    BiLSTMCNNNER(
        num_class=self.num_class,
        char_embeddings=self.char_embeddings,
        char_vocab_size=self.char_vocab_size,
        char_embed_dim=self.char_embed_dim,
        char_embed_trainable=False,
        use_word=False,
        use_crf=True,
    ).build_model()

    # char, CRF, word, no bert input -- build only
    BiLSTMCNNNER(
        num_class=self.num_class,
        char_embeddings=self.char_embeddings,
        char_vocab_size=self.char_vocab_size,
        char_embed_dim=self.char_embed_dim,
        char_embed_trainable=False,
        use_word=True,
        word_embeddings=self.word_embeddings,
        word_vocab_size=self.word_vocab_size,
        word_embed_dim=self.word_embed_dim,
        word_embed_trainable=False,
        use_crf=True,
    ).build_model()

    # char, CRF, word, bert
    ner_model = BiLSTMCNNNER(
        num_class=self.num_class,
        char_embeddings=self.char_embeddings,
        char_vocab_size=self.char_vocab_size,
        char_embed_dim=self.char_embed_dim,
        char_embed_trainable=False,
        use_bert=True,
        bert_config_file=self.bert_config_file,
        bert_checkpoint_file=self.bert_model_file,
        bert_trainable=True,
        use_word=True,
        word_embeddings=self.word_embeddings,
        word_vocab_size=self.word_vocab_size,
        word_embed_dim=self.word_embed_dim,
        word_embed_trainable=False,
        max_len=16,
        use_crf=True,
    ).build_model()

    # test save and load
    json_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner.json')
    weights_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner.hdf5')
    try:
        save_keras_model(ner_model, json_file, weights_file)
        assert os.path.exists(json_file)
        assert os.path.exists(weights_file)

        load_keras_model(json_file, weights_file, custom_objects=get_custom_objects())
    finally:
        # Remove whatever was written, even if a step above failed. The
        # existence guard keeps a missing file from masking the real error.
        for path in (json_file, weights_file):
            if os.path.exists(path):
                os.remove(path)

    assert not os.path.exists(json_file)
    assert not os.path.exists(weights_file)