Exemplo n.º 1
0
    def load_from_file(cls,
                       model_folder,
                       context: mx.Context = None,
                       **kwargs):
        """Rebuild a ``SequenceTagger`` saved in ``model_folder``.

        Hyper-parameters are read from ``config.pkl`` and trained weights from
        ``model.bin``, both inside ``model_folder``. The embedding stack is not
        read from the config: it is rebuilt here from fixed relative paths
        (fastText word embeddings plus forward and backward character-LM
        embeddings).

        :param model_folder: directory containing ``config.pkl`` and ``model.bin``.
        :param context: MXNet context to build the model in and load the
            weights onto; when ``None``, ``mxnet_prefer_gpu()`` picks one.
        :param kwargs: optional ``word_embedding_path`` (default ``''``). It is
            prepended verbatim to every embedding / LM path, with no separator
            inserted, so a non-empty value should end with one.
        :return: the ``SequenceTagger`` with its parameters loaded.
        """
        if context is None:
            context = mxnet_prefer_gpu()

        # NOTE(review): pickle can execute arbitrary code while loading; only
        # load model folders that come from a trusted source.
        config_path = os.path.join(model_folder, 'config.pkl')
        with open(config_path, 'rb') as f:
            config = pickle.load(f)
        # The config file is closed here. Building embeddings and loading
        # weights below can take a long time and does not need the handle.

        # Single lookup of the resource prefix shared by all embedding paths.
        prefix = kwargs.get('word_embedding_path', '')

        with context:
            embedding_types = [
                WordEmbeddings(
                    '{}data/embedding/fasttext100.vec.txt'.format(prefix)),

                # Uncomment this line to also use character embeddings:
                # CharacterEmbeddings(),

                # Contextual string embeddings (forward and backward char LMs):
                CharLMEmbeddings('{}data/model/lm-news-forward'.format(prefix),
                                 context=context),
                CharLMEmbeddings('{}data/model/lm-news-backward'.format(prefix),
                                 context=context),
            ]

            embeddings = StackedEmbeddings(embeddings=embedding_types)
            model = SequenceTagger(hidden_size=config['hidden_size'],
                                   embeddings=embeddings,
                                   tag_dictionary=config['tag_dictionary'],
                                   tag_type=config['tag_type'],
                                   use_crf=config['use_crf'],
                                   use_rnn=config['use_rnn'],
                                   rnn_layers=config['rnn_layers'])
            model.load_parameters(os.path.join(model_folder, 'model.bin'),
                                  ctx=context)
        return model
Exemplo n.º 2
0
        WordEmbeddings('data/embedding/glove/glove.6B.100d.txt'),
        BERTEmbeddings([
            'data/embedding/bert_large_sum/conll03.train.bert',
            'data/embedding/bert_large_sum/conll03.dev.bert',
            'data/embedding/bert_large_sum/conll03.test.bert'
        ]),

        # comment in this line to use character embeddings
        # CharacterEmbeddings(),

        # comment in these lines to use contextual string embeddings
        CharLMEmbeddings('data/model/lm-news-forward'),
        CharLMEmbeddings('data/model/lm-news-backward'),
    ]

    embeddings = StackedEmbeddings(embeddings=embedding_types)

    # 5. initialize sequence tagger
    tagger = SequenceTagger(hidden_size=256,
                            embeddings=embeddings,
                            tag_dictionary=tag_dictionary,
                            tag_type=tag_type,
                            use_crf=True)

    # 6. initialize trainer
    trainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)

    # 7. start training
    trainer.train(model_path,
                  learning_rate=0.1,
                  mini_batch_size=32,
Exemplo n.º 3
0
# 1. load the column-formatted corpus from data/wsj-pos; the column map
#    assigns column 0 to 'text' and column 1 to 'pos'
columns = {0: 'text', 1: 'pos'}
corpus = NLPTaskDataFetcher.fetch_column_corpus(
    'data/wsj-pos',
    columns,
    train_file='train.short.tsv',
    dev_file='dev.tsv',
    test_file='test.tsv',
)

# 2. the tag type the tagger is trained to predict
tag_type = 'pos'

# 3. build the tag dictionary for that tag type from the corpus and print
#    its idx2item list
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
print(tag_dictionary.idx2item)

# 4. initialize embeddings
# Everything below is constructed inside the MXNet context returned by
# mxnet_prefer_gpu(), so the embeddings, tagger and trainer are all created
# under that context.
with mx.Context(mxnet_prefer_gpu()):
    # TwoWayEmbeddings receives two embedding objects:
    #   1. a StackedEmbeddings of the forward and backward news char-LM embeddings
    #   2. BERTEmbeddings built from the three .bert files listed (presumably
    #      precomputed BERT-large-cased features for the train/dev/test splits)
    # NOTE(review): the trailing `128, 128` are positional TwoWayEmbeddings
    # arguments whose meaning is not visible here (presumably per-branch
    # dimensions) -- confirm against its signature.
    # NOTE(review): the BERT train file ('wsj.train.short.bert') presumably must
    # line up with the 'train.short.tsv' training file -- confirm.
    embeddings = TwoWayEmbeddings(StackedEmbeddings(
        [CharLMEmbeddings('data/model/lm-news-forward'), CharLMEmbeddings('data/model/lm-news-backward')]),
        BERTEmbeddings(['data/embedding/bert_large_cased/wsj.train.short.bert',
                        'data/embedding/bert_large_cased/wsj.dev.bert',
                        'data/embedding/bert_large_cased/wsj.test.bert']),
        128, 128)

    # 5. initialize sequence tagger (use_crf=False: built without a CRF layer)
    tagger = SequenceTagger(hidden_size=256,
                            embeddings=embeddings,
                            tag_dictionary=tag_dictionary,
                            tag_type=tag_type,
                            use_crf=False)

    # 6. initialize trainer on the tagger and corpus (test_mode disabled)
    trainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)