Example #1
import os

from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import (
    ElmoTokenEmbedder, PretrainedTransformerMismatchedEmbedder)

def _load_embedder(config, bert_max_length):
    """Build a BasicTextFieldEmbedder for the embedder named in the config."""
    if config.embedder.name == 'elmo':
        embedder = ElmoTokenEmbedder(
            options_file=os.path.join(config.data.pretrained_models_dir,
                                      'elmo/options.json'),
            weight_file=os.path.join(config.data.pretrained_models_dir,
                                     'elmo/model.hdf5'),
            dropout=0.)
        embedder.eval()  # keep ELMo in eval mode so its internal dropout stays off
    elif config.embedder.name.endswith('bert'):
        embedder = PretrainedTransformerMismatchedEmbedder(
            model_name=os.path.join(config.data.pretrained_models_dir,
                                    config.embedder.name),
            max_length=bert_max_length)
    elif config.embedder.name == 'both':
        # 'both' pairs the ELMo embedder with the Russian BERT ('ru_bert') model.
        elmo_embedder = ElmoTokenEmbedder(
            options_file=os.path.join(config.data.pretrained_models_dir,
                                      'elmo/options.json'),
            weight_file=os.path.join(config.data.pretrained_models_dir,
                                     'elmo/model.hdf5'),
            dropout=0.)
        elmo_embedder.eval()

        bert_embedder = PretrainedTransformerMismatchedEmbedder(
            model_name=os.path.join(config.data.pretrained_models_dir,
                                    'ru_bert'),
            max_length=bert_max_length)

        return BasicTextFieldEmbedder({
            'elmo': elmo_embedder,
            'ru_bert': bert_embedder
        })
    else:
        raise ValueError('Unknown embedder {}'.format(config.embedder.name))

    return BasicTextFieldEmbedder({config.embedder.name: embedder})
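
A minimal usage sketch (not part of the original source): it assumes config exposes the attribute paths read above, faked here with SimpleNamespace, and that the ELMo/BERT files actually exist under pretrained_models_dir.

from types import SimpleNamespace

config = SimpleNamespace(
    embedder=SimpleNamespace(name='ru_bert'),  # or 'elmo' / 'both'
    data=SimpleNamespace(pretrained_models_dir='/path/to/pretrained_models'))
text_field_embedder = _load_embedder(config, bert_max_length=512)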
Example #2
import os

import torch

from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import (
    ElmoTokenEmbedder, Embedding, PretrainedTransformerMismatchedEmbedder,
    TokenCharactersEncoder)

def _load_embedder(config, vocab, bert_max_length):
    """Build a BasicTextFieldEmbedder over every embedder listed in the config."""
    embedders = {}
    for embedder_config in config.embedder.models:
        if embedder_config.name == 'elmo':
            embedders[embedder_config.name] = ElmoTokenEmbedder(
                options_file=os.path.join(config.data.pretrained_models_dir,
                                          'elmo/options.json'),
                weight_file=os.path.join(config.data.pretrained_models_dir,
                                         'elmo/model.hdf5'),
                requires_grad=embedder_config.params['requires_grad'],
                dropout=0.)
            embedders[embedder_config.name].eval()
        elif embedder_config.name.endswith('bert'):
            embedders[embedder_config.name] = PretrainedTransformerMismatchedEmbedder(
                model_name=os.path.join(config.data.pretrained_models_dir,
                                        embedder_config.name),
                max_length=bert_max_length,
                requires_grad=embedder_config.params['requires_grad'])
        elif embedder_config.name == 'char_bilstm':
            # Character-level embeddings encoded by a bidirectional LSTM.
            embedders[embedder_config.name] = TokenCharactersEncoder(
                embedding=Embedding(
                    num_embeddings=vocab.get_vocab_size('token_characters'),
                    embedding_dim=embedder_config.params['char_embedding_dim']
                ),
                encoder=PytorchSeq2VecWrapper(
                    torch.nn.LSTM(
                        embedder_config.params['char_embedding_dim'],
                        embedder_config.params['lstm_dim'],
                        num_layers=embedder_config.params['lstm_num_layers'],
                        dropout=embedder_config.params['lstm_dropout'],
                        bidirectional=True,
                        batch_first=True)),
                dropout=embedder_config.params['dropout'])
        else:
            raise ValueError(
                'Unknown embedder {}'.format(embedder_config.name))

    return BasicTextFieldEmbedder(embedders)
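
A hypothetical invocation, assuming each entry in config.embedder.models carries a name plus a params dict as read above, and that vocab is an AllenNLP Vocabulary with a 'token_characters' namespace:

from types import SimpleNamespace

config = SimpleNamespace(
    data=SimpleNamespace(pretrained_models_dir='/path/to/pretrained_models'),
    embedder=SimpleNamespace(models=[
        SimpleNamespace(name='elmo', params={'requires_grad': False}),
        SimpleNamespace(name='char_bilstm',
                        params={'char_embedding_dim': 64,
                                'lstm_dim': 128,
                                'lstm_num_layers': 2,
                                'lstm_dropout': 0.1,
                                'dropout': 0.2}),
    ]))
text_field_embedder = _load_embedder(config, vocab, bert_max_length=512)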
Example #3
import torch

from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp_models.tagging.models.crf_tagger import CrfTagger

def build(params, vocab):
    """Assemble a BiLSTM-CRF tagger, optionally on top of BERT embeddings."""
    if params["with_bert"]:
        from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder
        # NOTE: the encoder below takes inputs of size params["dim_emb"], so
        # dim_emb must match the transformer's hidden size (e.g. 768).
        embedding = PretrainedTransformerMismatchedEmbedder(
            model_name=params["bert_name"], max_length=params["bert_max_len"])
    else:
        from allennlp.modules.token_embedders import Embedding
        embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                              embedding_dim=params["dim_emb"])
    embedder = BasicTextFieldEmbedder({"tokens": embedding})
    encoder = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(params["dim_emb"],
                      params["dim_hid"],
                      params["num_enc_layers"],
                      dropout=params["dropout"],
                      bidirectional=True,
                      batch_first=True))
    model = CrfTagger(vocab=vocab,
                      text_field_embedder=embedder,
                      encoder=encoder,
                      dropout=params["dropout"])
    return model
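
An illustrative call with placeholder hyperparameters; vocab is assumed to be an AllenNLP Vocabulary with 'tokens' and 'labels' namespaces (the latter is required by CrfTagger):

params = {
    "with_bert": False,
    "bert_name": "bert-base-multilingual-cased",  # only read when with_bert is True
    "bert_max_len": 512,
    "dim_emb": 128,
    "dim_hid": 256,
    "num_enc_layers": 2,
    "dropout": 0.3,
}
model = build(params, vocab)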