import os

from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import (
    ElmoTokenEmbedder, PretrainedTransformerMismatchedEmbedder)


def _load_embedder(config, bert_max_length):
    if config.embedder.name == 'elmo':
        embedder = ElmoTokenEmbedder(
            options_file=os.path.join(config.data.pretrained_models_dir,
                                      'elmo/options.json'),
            weight_file=os.path.join(config.data.pretrained_models_dir,
                                     'elmo/model.hdf5'),
            dropout=0.)
        # Keep the pretrained ELMo weights in eval mode.
        embedder.eval()
    elif config.embedder.name.endswith('bert'):
        embedder = PretrainedTransformerMismatchedEmbedder(
            model_name=os.path.join(config.data.pretrained_models_dir,
                                    config.embedder.name),
            max_length=bert_max_length)
    elif config.embedder.name == 'both':
        # Concatenate ELMo and ruBERT token representations.
        elmo_embedder = ElmoTokenEmbedder(
            options_file=os.path.join(config.data.pretrained_models_dir,
                                      'elmo/options.json'),
            weight_file=os.path.join(config.data.pretrained_models_dir,
                                     'elmo/model.hdf5'),
            dropout=0.)
        elmo_embedder.eval()
        bert_embedder = PretrainedTransformerMismatchedEmbedder(
            model_name=os.path.join(config.data.pretrained_models_dir, 'ru_bert'),
            max_length=bert_max_length)
        return BasicTextFieldEmbedder({
            'elmo': elmo_embedder,
            'ru_bert': bert_embedder
        })
    else:
        raise ValueError('Unknown embedder {}'.format(config.embedder.name))

    return BasicTextFieldEmbedder({config.embedder.name: embedder})
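# Hypothetical usage sketch (not from the original code): the config layout
# below is only inferred from the attribute accesses in _load_embedder above,
# and the paths are placeholders.
from types import SimpleNamespace

example_config = SimpleNamespace(
    data=SimpleNamespace(pretrained_models_dir='/path/to/pretrained_models'),
    embedder=SimpleNamespace(name='elmo'))
# text_field_embedder = _load_embedder(example_config, bert_max_length=512)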
import os

import torch
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import (
    ElmoTokenEmbedder, Embedding, PretrainedTransformerMismatchedEmbedder,
    TokenCharactersEncoder)


def _load_embedder(config, vocab, bert_max_length):
    embedders = {}
    for embedder_config in config.embedder.models:
        if embedder_config.name == 'elmo':
            embedders[embedder_config.name] = ElmoTokenEmbedder(
                options_file=os.path.join(config.data.pretrained_models_dir,
                                          'elmo/options.json'),
                weight_file=os.path.join(config.data.pretrained_models_dir,
                                         'elmo/model.hdf5'),
                requires_grad=embedder_config.params['requires_grad'],
                dropout=0.)
            embedders[embedder_config.name].eval()
        elif embedder_config.name.endswith('bert'):
            embedders[embedder_config.name] = PretrainedTransformerMismatchedEmbedder(
                model_name=os.path.join(config.data.pretrained_models_dir,
                                        embedder_config.name),
                max_length=bert_max_length,
                requires_grad=embedder_config.params['requires_grad'])
        elif embedder_config.name == 'char_bilstm':
            # Character-level BiLSTM encoder over each token's characters.
            embedders[embedder_config.name] = TokenCharactersEncoder(
                embedding=Embedding(
                    num_embeddings=vocab.get_vocab_size('token_characters'),
                    embedding_dim=embedder_config.params['char_embedding_dim']),
                encoder=PytorchSeq2VecWrapper(
                    torch.nn.LSTM(
                        embedder_config.params['char_embedding_dim'],
                        embedder_config.params['lstm_dim'],
                        num_layers=embedder_config.params['lstm_num_layers'],
                        dropout=embedder_config.params['lstm_dropout'],
                        bidirectional=True,
                        batch_first=True)),
                dropout=embedder_config.params['dropout'])
        else:
            raise ValueError('Unknown embedder {}'.format(embedder_config.name))

    return BasicTextFieldEmbedder(embedders)
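# Hypothetical usage sketch: each entry of config.embedder.models needs a .name
# and a .params dict, as inferred from the loop above; the real project's config
# format is unknown and the values here are illustrative only.
from types import SimpleNamespace

example_config = SimpleNamespace(
    data=SimpleNamespace(pretrained_models_dir='/path/to/pretrained_models'),
    embedder=SimpleNamespace(models=[
        SimpleNamespace(name='elmo', params={'requires_grad': False}),
        SimpleNamespace(name='char_bilstm',
                        params={'char_embedding_dim': 64,
                                'lstm_dim': 128,
                                'lstm_num_layers': 1,
                                'lstm_dropout': 0.0,
                                'dropout': 0.2}),
    ]))
# embedder = _load_embedder(example_config, vocab, bert_max_length=512)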
import torch
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp_models.tagging import CrfTagger


def build(params, vocab):
    if params["with_bert"]:
        from allennlp.modules.token_embedders.pretrained_transformer_mismatched_embedder import (
            PretrainedTransformerMismatchedEmbedder)
        embedding = PretrainedTransformerMismatchedEmbedder(
            model_name=params["bert_name"],
            max_length=params["bert_max_len"])
    else:
        from allennlp.modules.token_embedders import Embedding
        embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                              embedding_dim=params["dim_emb"])

    embedder = BasicTextFieldEmbedder({"tokens": embedding})
    # BiLSTM encoder over the embedded tokens; params["dim_emb"] must match the
    # embedder's output dimension (e.g. the transformer hidden size when
    # with_bert is set), otherwise CrfTagger's dimension check fails.
    encoder = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(params["dim_emb"],
                      params["dim_hid"],
                      params["num_enc_layers"],
                      dropout=params["dropout"],
                      bidirectional=True,
                      batch_first=True))
    model = CrfTagger(vocab=vocab,
                      text_field_embedder=embedder,
                      encoder=encoder,
                      dropout=params["dropout"])
    return model
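# Hypothetical usage sketch: the keys mirror those read inside build(); the
# values are illustrative and not taken from the original configuration.
example_params = {
    'with_bert': False,
    'bert_name': 'bert-base-multilingual-cased',
    'bert_max_len': 512,
    'dim_emb': 100,
    'dim_hid': 128,
    'num_enc_layers': 2,
    'dropout': 0.3,
}
# model = build(example_params, vocab)  # vocab: an AllenNLP Vocabulary built from the dataset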