Example #1
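Example #1 restores a trained tagger from disk: the pickled configuration is read from config.pkl, the fastText and character-LM embedding stack is rebuilt on the target device, and the saved weights are loaded from model.bin. The snippet assumes os, pickle, mxnet (imported as mx) and the project's embedding and tagger classes are already imported at module level.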
    @classmethod
    def load_from_file(cls,
                       model_folder,
                       context: mx.Context = None,
                       **kwargs):
        # fall back to the preferred device (GPU if available, else CPU)
        if context is None:
            context = mxnet_prefer_gpu()
        config_path = os.path.join(model_folder, 'config.pkl')
        with open(config_path, 'rb') as f:
            config = pickle.load(f)
            with context:
                embedding_types = [
                    WordEmbeddings(
                        '{}data/embedding/fasttext100.vec.txt'.format(
                            kwargs.get('word_embedding_path', ''))),

                    # uncomment this line to add character-level embeddings
                    # CharacterEmbeddings(),

                    # contextual string embeddings (forward and backward character language models)
                    CharLMEmbeddings('{}data/model/lm-news-forward'.format(
                        kwargs.get('word_embedding_path', '')),
                                     context=context),
                    CharLMEmbeddings('{}data/model/lm-news-backward'.format(
                        kwargs.get('word_embedding_path', '')),
                                     context=context),
                ]

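                # concatenate all active embeddings into one vector per token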
                embeddings = StackedEmbeddings(embeddings=embedding_types)
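                # rebuild the tagger with the hyper-parameters saved in config.pkl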
                model = SequenceTagger(hidden_size=config['hidden_size'],
                                       embeddings=embeddings,
                                       tag_dictionary=config['tag_dictionary'],
                                       tag_type=config['tag_type'],
                                       use_crf=config['use_crf'],
                                       use_rnn=config['use_rnn'],
                                       rnn_layers=config['rnn_layers'])
                model.load_parameters(os.path.join(model_folder, 'model.bin'),
                                      ctx=context)
            return model
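
A minimal usage sketch for the loader above. The model folder name, the Sentence class and the predict() call are assumptions that this MXNet port mirrors flair's tagging API; treat them as illustrative rather than confirmed.

model = SequenceTagger.load_from_file('data/model/conll03-ner',  # hypothetical folder
                                      context=mxnet_prefer_gpu())
sentence = Sentence('George Washington went to Washington .')    # assumed flair-style Sentence
model.predict(sentence)                                          # assumed flair-style predict()
print(sentence.to_tagged_string())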
Example #2
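Example #2 wires up training for a CoNLL-03 NER tagger: GloVe word vectors, precomputed BERT features and forward/backward character language models are stacked into one embedding, a CRF sequence tagger is built on top, and a trainer is created for the corpus. tag_dictionary, tag_type and corpus come from earlier steps not shown here.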
# 4. initialize embeddings
with mx.Context(mxnet_prefer_gpu()):
    embedding_types = [
        WordEmbeddings('data/embedding/glove/glove.6B.100d.txt'),
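        # precomputed BERT-large features for the CoNLL-03 train/dev/test splits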
        BERTEmbeddings([
            'data/embedding/bert_large_sum/conll03.train.bert',
            'data/embedding/bert_large_sum/conll03.dev.bert',
            'data/embedding/bert_large_sum/conll03.test.bert'
        ]),

        # uncomment this line to add character-level embeddings
        # CharacterEmbeddings(),

        # contextual string embeddings (forward and backward character language models)
        CharLMEmbeddings('data/model/lm-news-forward'),
        CharLMEmbeddings('data/model/lm-news-backward'),
    ]

    embeddings = StackedEmbeddings(embeddings=embedding_types)

    # 5. initialize sequence tagger
    tagger = SequenceTagger(hidden_size=256,
                            embeddings=embeddings,
                            tag_dictionary=tag_dictionary,
                            tag_type=tag_type,
                            use_crf=True)

    # 6. initialize trainer
    trainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
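
The listing stops after step 6. A hedged sketch of the training step follows, continuing inside the with mx.Context(...) block above; the output path and the hyper-parameter values are illustrative, and train()'s signature is assumed to mirror flair's SequenceTaggerTrainer rather than taken from the source.

    # 7. start training (illustrative values; assumed flair-style signature)
    trainer.train('data/model/conll03-ner',
                  learning_rate=0.1,
                  mini_batch_size=32,
                  max_epochs=150)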