コード例 #1
0
    def test_siamese_bilstm_model(self):
        """Build SiameseBiLSTM in two input configurations and round-trip the last.

        The first configuration (word + char) is only checked to build without
        raising. The second (char + bert) is also saved with
        ``save_keras_model`` and reloaded with ``load_keras_model`` using
        ``get_custom_objects()``. The saved JSON/weights files are removed even
        when an assertion or the reload fails. This stops a failed run from
        leaving files that would satisfy the existence checks of a later run.
        """
        # word, char: build-only check
        SiameseBiLSTM(num_class=self.num_class,
                      use_word=True,
                      word_embeddings=self.word_embeddings,
                      word_vocab_size=self.word_vocab_size,
                      word_embed_dim=self.word_embed_dim,
                      word_embed_trainable=False,
                      use_char=True,
                      char_embeddings=self.char_embeddings,
                      char_vocab_size=self.char_vocab_size,
                      char_embed_dim=self.char_embed_dim,
                      char_embed_trainable=False,
                      use_bert=False,
                      max_len=10).build_model()

        # char, bert: this model is also saved and reloaded below
        spm_model = SiameseBiLSTM(num_class=self.num_class,
                                  use_word=False,
                                  use_char=True,
                                  char_embeddings=self.char_embeddings,
                                  char_vocab_size=self.char_vocab_size,
                                  char_embed_dim=self.char_embed_dim,
                                  char_embed_trainable=False,
                                  use_bert=True,
                                  bert_config_file=self.bert_config_file,
                                  bert_checkpoint_file=self.bert_model_file,
                                  max_len=10).build_model()

        # test save and load
        json_file = os.path.join(self.checkpoint_dir,
                                 'siamese_bilstm_spm.json')
        weights_file = os.path.join(self.checkpoint_dir,
                                    'siamese_bilstm_spm.hdf5')

        def remove_artifacts():
            # Delete whichever of the two files exist; missing ones are fine.
            for path in (json_file, weights_file):
                if os.path.exists(path):
                    os.remove(path)

        # Clear leftovers from an earlier aborted run so the existence checks
        # below reflect this save, not a stale file.
        remove_artifacts()
        try:
            save_keras_model(spm_model, json_file, weights_file)
            assert os.path.exists(json_file)
            assert os.path.exists(weights_file)

            loaded_model = load_keras_model(json_file,
                                            weights_file,
                                            custom_objects=get_custom_objects())
            # The architecture rebuilt from JSON is expected to have the same
            # layers as the model that was saved.
            assert len(loaded_model.layers) == len(spm_model.layers)
        finally:
            # Always clean up, even if an assertion or the reload raised.
            remove_artifacts()
        assert not os.path.exists(json_file)
        assert not os.path.exists(weights_file)
コード例 #2
0
    def test_bert_model(self):
        """Build a trainable BertSPM, then save it and reload it.

        The model is written with ``save_keras_model`` and read back with
        ``load_keras_model`` using ``get_custom_objects()``. The JSON/weights
        files are removed even when an assertion or the reload fails. This
        stops a failed run from leaving files that would satisfy the
        existence checks of a later run.
        """
        spm_model = BertSPM(num_class=self.num_class,
                            bert_config_file=self.bert_config_file,
                            bert_checkpoint_file=self.bert_model_file,
                            bert_trainable=True,
                            max_len=10).build_model()

        # test save and load
        json_file = os.path.join(self.checkpoint_dir, 'bert_spm.json')
        weights_file = os.path.join(self.checkpoint_dir, 'bert_spm.hdf5')

        def remove_artifacts():
            # Delete whichever of the two files exist; missing ones are fine.
            for path in (json_file, weights_file):
                if os.path.exists(path):
                    os.remove(path)

        # Clear leftovers from an earlier aborted run so the existence checks
        # below reflect this save, not a stale file.
        remove_artifacts()
        try:
            save_keras_model(spm_model, json_file, weights_file)
            assert os.path.exists(json_file)
            assert os.path.exists(weights_file)

            loaded_model = load_keras_model(json_file,
                                            weights_file,
                                            custom_objects=get_custom_objects())
            # The architecture rebuilt from JSON is expected to have the same
            # layers as the model that was saved.
            assert len(loaded_model.layers) == len(spm_model.layers)
        finally:
            # Always clean up, even if an assertion or the reload raised.
            remove_artifacts()
        assert not os.path.exists(json_file)
        assert not os.path.exists(weights_file)
コード例 #3
0
 def test_load(self):
     """Load the model from the test JSON/weights files and check its size.

     ``MyLayer`` is passed as a custom object so the architecture can be
     deserialized. The loaded model is stored on ``self.model`` and must
     contain exactly three layers.
     """
     custom = {'MyLayer': MyLayer}
     loaded = load_keras_model(
         self.test_json_file,
         self.test_weights_file,
         custom_objects=custom,
     )
     self.model = loaded
     assert len(self.model.layers) == 3
コード例 #4
0
ファイル: test_ner_models.py プロジェクト: techwitz/fancy-nlp
    def test_bilstm_cnn_model(self):
        """Build BiLSTMCNNNER in four input configurations and round-trip the last.

        The first three configurations are only checked to build without
        raising:
        - char only, no CRF;
        - char with CRF;
        - char + word with CRF.

        The fourth (char + word + bert, CRF, ``max_len=16``) is also saved with
        ``save_keras_model`` and reloaded with ``load_keras_model`` using
        ``get_custom_objects()``. The JSON/weights files are removed even when
        an assertion or the reload fails. This stops a failed run from
        leaving files that would satisfy the existence checks of a later run.
        """
        # char, no CRF, no word input: build-only check
        BiLSTMCNNNER(num_class=self.num_class,
                     char_embeddings=self.char_embeddings,
                     char_vocab_size=self.char_vocab_size,
                     char_embed_dim=self.char_embed_dim,
                     char_embed_trainable=False,
                     use_word=False,
                     use_crf=False).build_model()

        # char, CRF, no word, no bert input: build-only check
        BiLSTMCNNNER(num_class=self.num_class,
                     char_embeddings=self.char_embeddings,
                     char_vocab_size=self.char_vocab_size,
                     char_embed_dim=self.char_embed_dim,
                     char_embed_trainable=False,
                     use_word=False,
                     use_crf=True).build_model()

        # char, CRF, word, no bert input: build-only check
        BiLSTMCNNNER(num_class=self.num_class,
                     char_embeddings=self.char_embeddings,
                     char_vocab_size=self.char_vocab_size,
                     char_embed_dim=self.char_embed_dim,
                     char_embed_trainable=False,
                     use_word=True,
                     word_embeddings=self.word_embeddings,
                     word_vocab_size=self.word_vocab_size,
                     word_embed_dim=self.word_embed_dim,
                     word_embed_trainable=False,
                     use_crf=True).build_model()

        # char, CRF, word, bert: this model is also saved and reloaded below
        ner_model = BiLSTMCNNNER(num_class=self.num_class,
                                 char_embeddings=self.char_embeddings,
                                 char_vocab_size=self.char_vocab_size,
                                 char_embed_dim=self.char_embed_dim,
                                 char_embed_trainable=False,
                                 use_bert=True,
                                 bert_config_file=self.bert_config_file,
                                 bert_checkpoint_file=self.bert_model_file,
                                 bert_trainable=True,
                                 use_word=True,
                                 word_embeddings=self.word_embeddings,
                                 word_vocab_size=self.word_vocab_size,
                                 word_embed_dim=self.word_embed_dim,
                                 word_embed_trainable=False,
                                 max_len=16,
                                 use_crf=True).build_model()

        # test save and load
        json_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner.json')
        weights_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner.hdf5')

        def remove_artifacts():
            # Delete whichever of the two files exist; missing ones are fine.
            for path in (json_file, weights_file):
                if os.path.exists(path):
                    os.remove(path)

        # Clear leftovers from an earlier aborted run so the existence checks
        # below reflect this save, not a stale file.
        remove_artifacts()
        try:
            save_keras_model(ner_model, json_file, weights_file)
            assert os.path.exists(json_file)
            assert os.path.exists(weights_file)

            loaded_model = load_keras_model(json_file,
                                            weights_file,
                                            custom_objects=get_custom_objects())
            # The architecture rebuilt from JSON is expected to have the same
            # layers as the model that was saved.
            assert len(loaded_model.layers) == len(ner_model.layers)
        finally:
            # Always clean up, even if an assertion or the reload raised.
            remove_artifacts()
        assert not os.path.exists(json_file)
        assert not os.path.exists(weights_file)