def out_embedding(type_, model, n_layers, stacked=False):
    '''
    Create object of embedding type for later purpose
    :param:
        :type_: (str) type of embedding (currently only BERT or Flair embeddings)
        :model: (str) pretrained model of BERT embedding
        :n_layers: (int) number of last layers of trained BERT embeddings to be chosen
        :stacked: (bool) if this embedding is a combination of more embeddings (True/False)
    :return:
        :embedding: (BertEmbeddings / StackedEmbeddings) embedding object
    '''
    # BERT layers are addressed from the top: "-1,-2,...,-n_layers".
    out_layers = ','.join(str(-layer) for layer in range(1, n_layers + 1))

    if not stacked:
        if type_.lower() == 'bert':
            # Plain (non-stacked) BERT embedding — return it directly.
            return BertEmbeddings(bert_model_or_path=model, layers=out_layers)
        # Non-BERT request: use GloVe as the base and stack with Flair below.
        emb = WordEmbeddings('glove')
    else:
        emb = BertEmbeddings(bert_model_or_path=model, layers=out_layers)

    # Combine the base embedding with the fast forward/backward Flair LMs.
    return StackedEmbeddings(embeddings=[
        FlairEmbeddings('news-forward-fast'),
        FlairEmbeddings('news-backward-fast'),
        emb,
    ])
def main():
    """Read a CoNLL-style TSV (token in the first column, blank line between
    sentences), embed every sentence with BERT, and print one tab-separated
    embedding vector per token, with a blank line separating sentences.
    """
    args = _parse_args()
    tsv_path = args.tsv_path
    embedding = BertEmbeddings('bert-base-cased')
    sentences = [[]]
    with open(tsv_path, 'r') as f:
        # Iterate the file lazily instead of materializing f.readlines().
        for l in f:
            if l.strip():
                # First tab-separated column is the token; ignore the rest.
                token, *_ = l.strip().split('\t')
                sentences[-1].append(token.lower())
            else:
                sentences.append([])
    # Skip empty groups (e.g. trailing/duplicate blank lines) — they would
    # otherwise produce empty Sentence objects and garbage output.
    f_sentences = [Sentence(' '.join(s)) for s in sentences if s]
    for s in progressbar.progressbar(f_sentences):
        embedding.embed(s)
        for t in s:
            print('\t'.join(t.embedding.numpy().astype(str)))
        print()
        # Release the per-sentence embedding tensors to keep memory bounded.
        s.clear_embeddings()
def initialize_embeddings(self, fastbert=True, stackedembeddings=True):
    """Build the token embeddings used by the model.

    :param fastbert: use the lighter DistilBERT model instead of bert-base-cased
    :param stackedembeddings: stack BERT with GloVe and Flair LM embeddings
    :return: a StackedEmbeddings or a single BertEmbeddings instance
    """
    # Consider using pooling_operation="first", use_scalar_mix=True for the parameters
    # Only the last transformer layer is used in either case.
    if fastbert:
        bert_embedding = BertEmbeddings('distilbert-base-uncased', layers='-1')
    else:
        bert_embedding = BertEmbeddings('bert-base-cased', layers='-1')

    if not stackedembeddings:
        return bert_embedding

    # Stack BERT with GloVe plus forward/backward Flair language models.
    glove_embedding = WordEmbeddings('glove')
    flair_embedding_forward = FlairEmbeddings('news-forward')
    flair_embedding_backward = FlairEmbeddings('news-backward')
    return StackedEmbeddings(embeddings=[
        bert_embedding,
        glove_embedding,
        flair_embedding_forward,
        flair_embedding_backward,
    ])
def __init__(self, model: Optional[BertEmbeddings] = None):
    """Store the given BertEmbeddings, or construct the default
    'bert-base-uncased' model when none is provided."""
    super(BertPretrained, self).__init__()
    self.model = model if model is not None else BertEmbeddings('bert-base-uncased')
def __init__(self, len, emb='en'):
    """
    Args:
        len (int): max length for the model input
        emb (str, optional): embedding language. Defaults to 'en'.

    Raises:
        ValueError: if `emb` names an unsupported language. (Previously an
        unsupported value silently left the instance without an `embedder`
        attribute, causing a confusing AttributeError much later.)
    """
    # NOTE: parameter name `len` shadows the builtin but is kept for
    # backward compatibility with existing callers.
    if emb == 'en':
        self.embedder = BertEmbeddings("distilbert-base-uncased")
    else:
        raise ValueError(f"unsupported embedding language: {emb!r}")
    self.MAX_LEN = len
class BertPretrained(ModelBase):
    """
    Encapsulates pretrained Bert Embeddings (from Zalando Flair) by
    conforming to the ModelBase interface.
    """

    def __init__(self, model: Optional[BertEmbeddings] = None):
        super(BertPretrained, self).__init__()
        # Fall back to the default uncased base model when none is injected.
        self.model = model if model is not None else BertEmbeddings('bert-base-uncased')

    def dim(self) -> int:
        """
        The dimensionality of created embeddings.

        :return: 3072 (for now, #TODO)
        """
        return 3072

    def get_word_vector(self, word: str) -> Optional[np.ndarray]:
        """
        Returns the word vector for word |word| or None. It is discouraged
        to use this method as it invalidates the purpose of Bert embeddings.
        Instead, utilize the context as well for more accurate vectorization.
        In reality, Bert embeddings never return None, even for bogus words.

        :param word: The word to vectorize.
        :return: Either the word vector or None.
        """
        # Embed the word as a one-token "sentence" and read that token back.
        single = Sentence(word)
        self.model.embed(single)
        first_token = list(single)[0]
        return np.array(first_token.embedding)

    def get_word_vectors(self, words: List[str]) -> List[np.ndarray]:
        """
        Vectorizes the list of words, using pretrained Bert embeddings.
        These embeddings are context dependent, so this method is preferred
        over fetching word vectors for single words.

        :param words: The list of words to vectorize.
        :return: A list of word vectors.
        """
        sentence = Sentence(' '.join(words))
        self.model.embed(sentence)
        # One contextual vector per token, in sentence order.
        return [np.array(token.embedding) for token in sentence]

    def vectorize_context(self, words: List[str]) -> Optional[np.ndarray]:
        """
        Transforms the context into a single vector. May return None in
        extreme cases, e.g. if |words| is an empty list.

        :param words: List of tokens describing the context.
        :return: A single word vector or None.
        """
        return self.mean_of_words(self.get_word_vectors(words))
class SentenceBertEmbedderSensor(SentenceSensor):
    """Sensor that attaches BERT embeddings to the wrapped sentence in place."""

    def __init__(self, *pres):
        super().__init__(*pres)
        # Default flair BERT model; embedding happens as a side effect in forward().
        self.bert_embedding = BertEmbeddings()

    def forward(self) -> Any:
        # Embed in place; the sensor itself produces no value.
        sentence = self.fetch_value(self.sentence_value)
        self.bert_embedding.embed(sentence)
        return None
def get_flair_bert_embeddings(words):
    # Experimental -- not tested
    """Embed `words` (a raw text string) with multilingual BERT and return
    the flair Sentence carrying per-token embeddings."""
    from flair.embeddings import BertEmbeddings

    embedder = BertEmbeddings('bert-base-multilingual-cased')
    sentence = Sentence(words)
    embedder.embed(sentence)
    return sentence
def dump_bert_vecs(df, dump_dir):
    """Compute contextual BERT vectors for every non-stopword token in
    df["sentence"] and pickle each vector to <dump_dir><word>/<count>.pkl.

    :param df: DataFrame with a "sentence" column of raw text
    :param dump_dir: output directory prefix (one sub-directory per word)
    """
    print("Getting BERT vectors...")
    embedding = BertEmbeddings('bert-base-uncased')
    word_counter = defaultdict(int)
    stop_words = set(stopwords.words('english'))
    stop_words.add("would")
    except_counter = 0
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    for index, row in df.iterrows():
        if index % 100 == 0:
            print("Finished sentences: " + str(index) + " out of " + str(len(df)))
        # all sentences are undercase now
        line = row["sentence"].lower()
        sentences = sent_tokenize(line)
        for sentence_ind, sent in enumerate(sentences):
            tokenized_text = tokenizer.tokenize(sent)
            if len(tokenized_text) > 512:
                print('sentence too long for Bert: truncating')
                # BUG FIX: the original used sent[:512], which slices the raw
                # *string* and joins its first 512 characters with spaces.
                # Truncate the wordpiece token list instead, matching the
                # length check above.
                sentence = Sentence(' '.join(tokenized_text[:512]),
                                    use_tokenizer=True)
            else:
                sentence = Sentence(sent, use_tokenizer=True)
            try:
                embedding.embed(sentence)
            except Exception as e:
                # Best-effort: log and skip sentences BERT cannot embed.
                except_counter += 1
                print("Exception Counter while getting BERT: ",
                      except_counter, sentence_ind, index, e)
                print(sentence)
                continue
            for token_ind, token in enumerate(sentence):
                word = token.text
                # Strip punctuation; skip stopwords, path-unsafe words and
                # words that became empty after stripping.
                word = word.translate(
                    str.maketrans('', '', string.punctuation))
                if word in stop_words or "/" in word or len(word) == 0:
                    continue
                word_dump_dir = dump_dir + word
                os.makedirs(word_dump_dir, exist_ok=True)
                fname = word_dump_dir + "/" + str(
                    word_counter[word]) + ".pkl"
                word_counter[word] += 1
                vec = token.embedding.cpu().numpy()
                try:
                    with open(fname, "wb") as handler:
                        pickle.dump(vec, handler)
                except Exception as e:
                    except_counter += 1
                    print("Exception Counter while dumping BERT: ",
                          except_counter, sentence_ind, index, word, e)
def _collect_embeddings(self, embeddings):
    """Map embedding identifiers to instantiated flair embedding objects.

    :param embeddings: iterable of keys from the registry below
    :return: list of embedding instances, in request order
    """
    registry = {
        "flair-forward": FlairEmbeddings("multi-forward"),
        "flair-backward": FlairEmbeddings("multi-backward"),
        "charm-forward": CharLMEmbeddings("news-forward"),
        "charm-backward": CharLMEmbeddings("news-backward"),
        "glove": WordEmbeddings("glove"),
        "bert-small": BertEmbeddings("bert-base-uncased"),
        "bert-large": BertEmbeddings("bert-large-uncased"),
        "elmo-small": ELMoEmbeddings("small"),
        "elmo-large": ELMoEmbeddings("original"),
    }
    return [registry[key] for key in embeddings]
def __init__(self, gpu, bert_embeddings_dim, freeze_bert_embeddings=True):
    """Layer producing BERT token embeddings with a linear projection.

    :param gpu: device identifier forwarded to the base layer
    :param bert_embeddings_dim: output dimension of the Wbert projection
    :param freeze_bert_embeddings: if True, BERT weights are not fine-tuned
        (stored on self.freeze_char_embeddings — note the historical name)
    """
    super(LayerBertEmbeddings, self).__init__(gpu)
    self.gpu = gpu
    self.bert_embeddings_dim = bert_embeddings_dim
    # NOTE(review): attribute name says "char" but it actually records the
    # freeze flag for the *BERT* embeddings; kept for compatibility.
    self.freeze_char_embeddings = freeze_bert_embeddings
    # self.bert = BertModel.from_pretrained("/home/jlfu/saved_pytorch_bert/en_base_uncased/model")
    # # for p in self.bert.parameters():
    # #     p.requires_grad = True
    # self.tokenizer = BertTokenizer.from_pretrained('/home/jlfu/saved_pytorch_bert/en_base_uncased/vocab.txt')
    # self.bert = BertModel.from_pretrained("/home/jlfu/model/cased_L-12_H-768_A-12/bert_model.ckpt.gz")
    # self.bert = BertModel.from_pretrained('bert-base-cased')
    # NOTE(review): hard-coded local checkpoint paths — machine-specific;
    # confirm these exist before deployment elsewhere.
    self.bert = BertModel.from_pretrained(
        "/home/jlfu/model/cased_L-12_H-768_A-12/bert-base-cased.tar.gz")
    self.tokenizer = BertTokenizer.from_pretrained(
        '/home/jlfu/model/cased_L-12_H-768_A-12/vocab.txt')
    # Projects the 768-dim BERT hidden state down/up to bert_embeddings_dim.
    self.Wbert = nn.Linear(768, bert_embeddings_dim)
    self.output_dim = 768
    # self.output_dim = 3072
    # Separate flair wrapper used alongside the raw BertModel above.
    self.bert_embedding = BertEmbeddings(
        "bert-base-cased"
    )  # bert-base-cased, bert-base-multilingual-cased
def construct_embeddings(
        self,
        embeddings: Sequence[str]) -> List[flair.embeddings.TokenEmbeddings]:
    """Instantiate flair embedding objects from configuration strings.

    :param embeddings: config names; the prefix (bert/flair/elmo/word)
        selects the embedding class
    :return: list of instantiated TokenEmbeddings
    :raises ValueError: on an unrecognized prefix
    """
    result = list()
    for name in embeddings:
        emb_type, emb_name, o_file, w_file = \
            FlairDataset.extract_embedding_configs(name)
        self.embeddings_type = emb_type.lower()
        if name.startswith('bert'):
            bert_embs = BertEmbeddings(emb_name)
            # Keep the tokenizer handy for downstream wordpiece alignment.
            self.bert_tokenizer = bert_embs.tokenizer
            result.append(bert_embs)
        elif name.startswith('flair'):
            result.append(FlairEmbeddings(emb_name))
        elif name.startswith('elmo'):
            result.append(
                ELMoEmbeddings(emb_name,
                               options_file=o_file,
                               weight_file=w_file))
        elif name.startswith('word'):
            result.append(WordEmbeddings(emb_name))
        else:
            raise ValueError('Invalid Embedding Type: "{}"'.format(name))
    return result
class BertEmbedding(EmbeddingBase):
    """Sentence embedder: mean of L2-normalized multilingual-BERT token vectors."""

    def __init__(self):
        self.model = BertEmbeddings(
            bert_model_or_path="bert-base-multilingual-cased")
        # 4 concatenated 768-dim layers (flair default) = 3072.
        self.size = 3072

    def _get_vector(self, sentence: Sentence) -> np.ndarray:
        """Average the unit-normalized token embeddings of `sentence`."""
        res = np.zeros(self.size, dtype=np.float32)
        for token in sentence.tokens:
            vec = np.fromiter(token.embedding.tolist(), dtype=np.float32)
            vec = vec / np.linalg.norm(vec, ord=2)
            res += vec
        res /= len(sentence.tokens)
        return res

    def batcher(self, params, batch: List[List[str]]) -> np.ndarray:
        """Embed a batch of tokenized sentences; returns (batch, size) array."""
        # BUG FIX: the empty-sentence fallback previously inserted the raw
        # list ['.'] instead of a Sentence, which model.embed cannot handle.
        # Use a single-period placeholder Sentence for empty inputs.
        batch = [
            Sentence(" ".join(sent)) if sent != [] else Sentence(".")
            for sent in batch
        ]
        embeddings = []
        sentences = self.model.embed(batch)
        for sent in sentences:
            embeddings.append(self._get_vector(sent))
        embeddings = np.vstack(embeddings)
        return embeddings

    def dim(self) -> int:
        """Dimensionality of the produced sentence vectors."""
        return self.size
def __init__(self, *embeddings: str):
    """Build a StackedEmbeddings from model names/paths, choosing the
    flair wrapper class from substrings of each name. Unknown names are
    reported and skipped; at least one must resolve."""
    print("May need a couple moments to instantiate...")
    self.embedding_stack = []
    # Load correct Embeddings module
    for model_name_or_path in embeddings:
        is_flair = ("flair" in model_name_or_path
                    or model_name_or_path in FLAIR_PRETRAINED_MODEL_NAMES)
        # Order matters: "roberta" contains "bert", so test roberta first
        # via the exclusion in the bert branch.
        if "roberta" in model_name_or_path:
            module = RoBERTaEmbeddings(model_name_or_path)
        elif "bert" in model_name_or_path:
            module = BertEmbeddings(model_name_or_path)
        elif "gpt2" in model_name_or_path:
            module = OpenAIGPT2Embeddings(model_name_or_path)
        elif "xlnet" in model_name_or_path:
            module = XLNetEmbeddings(model_name_or_path)
        elif "xlm" in model_name_or_path:
            module = XLMEmbeddings(model_name_or_path)
        elif is_flair:
            module = FlairEmbeddings(model_name_or_path)
        else:
            print(
                f"Corresponding flair embedding module not found for {model_name_or_path}"
            )
            continue
        self.embedding_stack.append(module)
    assert len(self.embedding_stack) != 0
    self.stacked_embeddings = StackedEmbeddings(
        embeddings=self.embedding_stack)
class Bert(nn.Module):
    """Wraps flair BERT embeddings behind an nn.Module interface, mapping
    token-index batches to stacked embedding tensors."""

    def __init__(self, idx2word, device=torch.device('cpu')):
        super(Bert, self).__init__()
        self.idx2word = idx2word
        self.embed_size = sizes["bert"]
        # Use only the second-to-last layer ('-2') of bert-base-uncased.
        self.bert = BertEmbeddings('bert-base-uncased', '-2')

    def proc(self, string):
        """Translate corpus placeholders to BERT special tokens."""
        special = {'.': "[SEP]", "__": "[MASK]"}
        return special.get(string, string)

    def forward(self, batch):
        # TODO: fill this in
        # batch is (seq_len, batch); transpose so each row is one sentence.
        sentences = []
        for token_ids in batch.transpose(0, 1).tolist():
            words = [self.proc(str(self.idx2word[t])) for t in token_ids]
            sentences.append(Sentence(' '.join(words)))
        embedded = self.bert.embed(sentences)
        per_sentence = [
            torch.stack([token.embedding for token in sentence])
            for sentence in embedded
        ]
        # Back to (seq_len, batch, embed) and onto the GPU.
        return torch.stack(per_sentence).transpose(0, 1).cuda()
def generate_topics_on_series(series):
    """https://towardsdatascience.com/covid-19-with-a-flair-2802a9f4c90f

    Embed every text in `series` with pooled BERT+Flair document embeddings.

    Returns:
        numpy array of shape (len(series), 7168), one row per text.
    """
    validate_text(series)
    # initialise embedding classes
    flair_embedding_forward = FlairEmbeddings("news-forward")
    flair_embedding_backward = FlairEmbeddings("news-backward")
    bert_embedding = BertEmbeddings("bert-base-uncased")
    # combine word embedding models
    document_embeddings = DocumentPoolEmbeddings(
        [bert_embedding, flair_embedding_backward, flair_embedding_forward])
    # Pre-allocate the output tensor on the GPU; 7168 = pooled dims of the
    # three stacked embeddings.
    X = torch.empty(size=(len(series.index), 7168)).cuda()
    for row, text in enumerate(tqdm(series)):
        sentence = Sentence(text)
        document_embeddings.embed(sentence)
        X[row] = sentence.get_embedding()
    X = X.cpu().detach().numpy()
    torch.cuda.empty_cache()
    return X
def transform(self, X: dt.Frame):
    """Score cosine similarity between document embeddings of the first two
    text columns of X. Rows that fail to embed score -99 (best-effort).

    :param X: datatable Frame whose first two columns are text
    :return: np.ndarray of similarity scores, one per row
    """
    X.replace([None, math.inf, -math.inf], self._repl_val)
    from flair.embeddings import WordEmbeddings, BertEmbeddings, DocumentPoolEmbeddings, Sentence
    if self.embedding_name in ["glove", "en"]:
        self.embedding = WordEmbeddings(self.embedding_name)
    elif self.embedding_name in ["bert"]:
        self.embedding = BertEmbeddings()
    self.doc_embedding = DocumentPoolEmbeddings([self.embedding])
    output = []
    X = X.to_pandas()
    text1_arr = X.iloc[:, 0].values
    text2_arr = X.iloc[:, 1].values
    for ind, text1 in enumerate(text1_arr):
        try:
            text1 = Sentence(str(text1).lower())
            self.doc_embedding.embed(text1)
            text2 = text2_arr[ind]
            text2 = Sentence(str(text2).lower())
            self.doc_embedding.embed(text2)
            score = cosine_similarity(text1.get_embedding().reshape(1, -1),
                                      text2.get_embedding().reshape(1, -1))[0, 0]
            output.append(score)
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Keep the deliberate best-effort
            # sentinel but only for ordinary errors.
            output.append(-99)
    return np.array(output)
def __init__(self, pipeline):
    """Assemble word/document embedders from the pipeline configuration.

    Component names select the flair class; two-letter codes are treated
    as fastText+BytePair language codes.
    """
    self.mode = pipeline.mode
    self.type = pipeline.embedding_type
    embedders = []
    for component in pipeline.embedders:
        if "forward" in component or "backward" in component:
            embedders.append(FlairEmbeddings(component))
        elif "glove" in component:
            embedders.append(WordEmbeddings(component))
        elif "bert" in component:
            embedders.append(BertEmbeddings(component))
        elif len(component) == 2:
            # see https://github.com/zalandoresearch/flair/blob/master/resources/docs/embeddings/FASTTEXT_EMBEDDINGS.md#fasttext-embeddings
            embedders.append(WordEmbeddings(component))
            embedders.append(BytePairEmbeddings(component))
        else:
            raise ValueError(f"unknown embedder: {component}")

    if self.type == "document":
        self.embedder = self._make_doc_embedder(pipeline, embedders)
    elif self.type == "word":
        self.embedder = StackedEmbeddings(embedders)
    elif self.type == "both":
        self.embedders = [
            self._make_doc_embedder(pipeline, embedders),
            StackedEmbeddings(embedders),
        ]
    else:
        raise ValueError(
            f"Innapropriate embedding type {pipeline.embedding_type}, "
            "should be 'word', 'document', or 'both'.")
def _get_vectorizer(vectorizer, training_mode, features_file="features.pkl"):
    """Build the text vectorizer named by `vectorizer`.

    :param vectorizer: class name of the vectorizer to construct
    :param training_mode: when False, the saved vocabulary is loaded from
        `features_file` (Hashing/DocumentPool variants need no vocabulary)
    :param features_file: pickle file holding the fitted vocabulary
    :return: the constructed vectorizer
    :raises ValueError: on an unrecognized vectorizer name
    """
    # Tokens are whitespace-separated runs of non-space characters.
    token_pattern = r'\S+'
    if training_mode:
        vocabulary = None
    else:
        if vectorizer not in [HashingVectorizer.__name__,
                              DocumentPoolEmbeddings.__name__]:
            vocabulary = pickle_manager.load(features_file)
    if vectorizer == TfidfVectorizer.__name__:
        v = TfidfVectorizer(input='content', encoding='utf-8',
                            decode_error='strict', strip_accents=None,
                            lowercase=True, preprocessor=None, tokenizer=None,
                            analyzer='word', stop_words=[],
                            token_pattern=token_pattern, ngram_range=(1, 1),
                            max_df=1.0, min_df=1, max_features=None,
                            vocabulary=vocabulary, binary=False,
                            dtype=np.float64, norm='l2', use_idf=True,
                            smooth_idf=True, sublinear_tf=False)
    elif vectorizer == CountVectorizer.__name__:
        v = CountVectorizer(input='content', encoding='utf-8',
                            decode_error='strict', strip_accents=None,
                            lowercase=True, preprocessor=None, tokenizer=None,
                            stop_words=[], token_pattern=token_pattern,
                            ngram_range=(1, 1), analyzer='word', max_df=1.0,
                            min_df=1, max_features=None, vocabulary=vocabulary,
                            binary=False, dtype=np.int64)
    elif vectorizer == HashingVectorizer.__name__:
        v = HashingVectorizer(input='content', encoding='utf-8',
                              decode_error='strict', strip_accents=None,
                              lowercase=True, preprocessor=None,
                              tokenizer=None, stop_words=[],
                              token_pattern=token_pattern, ngram_range=(1, 1),
                              analyzer='word', n_features=1048576,
                              binary=False, norm='l2', alternate_sign=True,
                              non_negative=False, dtype=np.float64)
    elif vectorizer == DocumentPoolEmbeddings.__name__:
        v = DocumentPoolEmbeddings(
            [BertEmbeddings('bert-base-multilingual-uncased')])
    else:
        # BUG FIX: `raise "Invalid vectorizer."` raises a string, which is a
        # TypeError in Python 3 ("exceptions must derive from BaseException").
        raise ValueError("Invalid vectorizer.")
    return v
def compute_query_embedding_pretrained_models(query):
    """
    learn bert sentence embeddings using pre-trained word embedding model
    for an arbitrary question
    :param query: String :: arbitrary sentence
    :return: n-dimensional embedding
    """
    query = query
    # Resolve config.ini next to this module.
    conf_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                             'config.ini')
    embedding_name = 'bert'
    # Pull all BERT hyper-parameters from the [Transformers] section.
    transformer_model_or_path = get_config(
        'Transformers', '{}_model_or_path'.format(embedding_name),
        conf_path, 0)
    transformer_layers = get_config(
        'Transformers', '{}_layers'.format(embedding_name), conf_path, 0)
    transformer_pooling_operation = get_config(
        'Transformers', '{}_pooling_operation'.format(embedding_name),
        conf_path, 0)
    transformer_use_scalar_mix = get_config(
        'Transformers', '{}_use_scalar_mix'.format(embedding_name),
        conf_path, 1)
    word_embedding = BertEmbeddings(
        bert_model_or_path=transformer_model_or_path,
        layers=transformer_layers,
        pooling_operation=transformer_pooling_operation,
        use_scalar_mix=transformer_use_scalar_mix)
    document_embeddings = DocumentPoolEmbeddings([word_embedding],
                                                 fine_tune_mode='none')
    query_embedding = compute_pretrained_individual_transformer_embedding(
        query, word_embedding, document_embeddings)
    # Return "<query>,<comma-separated vector>" (list brackets stripped).
    return query + ',' + str(list(query_embedding))[1:-1]
def download_flair_models():
    """Instantiate each embedding once so flair caches the model files locally."""
    for embedding_cls, name in (
        (WordEmbeddings, "en-crawl"),
        (WordEmbeddings, "news"),
        (FlairEmbeddings, "news-forward-fast"),
        (FlairEmbeddings, "news-backward-fast"),
        (FlairEmbeddings, "mix-forward"),
        (BertEmbeddings, "bert-base-uncased"),
    ):
        # Construction triggers the download; the object itself is discarded.
        embedding_cls(name)
def __init__(self, *embeddings: str):
    """Build a StackedEmbeddings from model names/paths. Names are matched
    by substring to a flair wrapper class; anything unmatched is tried as
    classic WordEmbeddings before failing."""
    print("May need a couple moments to instantiate...")
    self.embedding_stack = []
    # Load correct Embeddings module
    for model_name_or_path in embeddings:
        # "roberta" contains "bert", so check it before the bert branch.
        if "roberta" in model_name_or_path:
            module = RoBERTaEmbeddings(model_name_or_path)
        elif "bert" in model_name_or_path:
            module = BertEmbeddings(model_name_or_path)
        elif "gpt2" in model_name_or_path:
            module = OpenAIGPT2Embeddings(model_name_or_path)
        elif "xlnet" in model_name_or_path:
            module = XLNetEmbeddings(model_name_or_path)
        elif "xlm" in model_name_or_path:
            module = XLMEmbeddings(model_name_or_path)
        elif ("flair" in model_name_or_path
              or model_name_or_path in FLAIR_PRETRAINED_MODEL_NAMES):
            module = FlairEmbeddings(model_name_or_path)
        else:
            # Last resort: treat the key as a classic word embedding.
            try:
                module = WordEmbeddings(model_name_or_path)
            except ValueError:
                raise ValueError(
                    f"Embeddings not found for the model key: {model_name_or_path}, check documentation or custom model path to verify specified model"
                )
        self.embedding_stack.append(module)
    assert len(self.embedding_stack) != 0
    self.stacked_embeddings = StackedEmbeddings(
        embeddings=self.embedding_stack)
def create_embeddings(params):
    """Create token embeddings according to params["embedding_type"].

    :param params: dict with "embedding_type" in {"bert", "flair", "char"}
        and, for "bert", a "bert_model_dirpath_or_name" entry
    :return: a StackedEmbeddings instance
    :raises ValueError: on an unsupported embedding_type (previously an
        `assert`, which disappears under `python -O`)
    """
    embedding_type = params["embedding_type"]
    # BUG FIX: input validation via `assert` is stripped under -O; raise a
    # real exception instead. This also removes the unreachable
    # `embeddings = None` fallback.
    if embedding_type not in ("bert", "flair", "char"):
        raise ValueError(
            "embedding_type must be one of 'bert', 'flair', 'char', "
            "got {!r}".format(embedding_type))
    if embedding_type == "bert":
        bert_embedding = BertEmbeddings(params["bert_model_dirpath_or_name"],
                                        pooling_operation="mean")
        embedding_types: List[TokenEmbeddings] = [bert_embedding]
        embeddings: StackedEmbeddings = StackedEmbeddings(
            embeddings=embedding_types)
    elif embedding_type == "flair":
        glove_embedding = WordEmbeddings(
            '/opt/kanarya/glove/GLOVE/GloVe/vectors.gensim')
        word2vec_embedding = WordEmbeddings(
            '/opt/kanarya/huawei_w2v/vector.gensim')
        fast_text_embedding = WordEmbeddings('tr')
        char_embedding = CharacterEmbeddings()
        # bert_embedding = BertEmbeddings('../bert_pretraining/pretraining_outputs/pretraining_output_batch_size_32')
        embedding_types: List[TokenEmbeddings] = [
            fast_text_embedding, glove_embedding, word2vec_embedding,
            char_embedding
        ]
        # embedding_types: List[TokenEmbeddings] = [custom_embedding]
        embeddings: StackedEmbeddings = StackedEmbeddings(
            embeddings=embedding_types)
    else:  # "char"
        embeddings: StackedEmbeddings = StackedEmbeddings(
            embeddings=[CharacterEmbeddings()])
    return embeddings
def get_Bert_embeddings(vocab, dim):
    """Embed the whole vocabulary as one flair Sentence and return a
    (len(vocab), dim) matrix indexed by vocab[word].

    :param vocab: mapping word -> row index
    :param dim: embedding dimensionality
    :return: numpy array of per-word BERT vectors
    """
    from flair.embeddings import BertEmbeddings
    from flair.data import Sentence

    _embeddings = np.zeros([len(vocab), dim])
    # NOTE: all words are joined into a single "sentence", so each word's
    # vector is contextualized by its neighbours in vocab order.
    words = [w for w in vocab]
    sentence = Sentence(' '.join(words))
    BertEmbeddings().embed(sentence)
    for token in sentence:
        _embeddings[vocab[token.text]] = token.embedding
    return _embeddings
def get_bert(model_name, layers='-1,-2,-3,-4', pooling_op='first', scalar_mix=False):
    """Shortcut constructor for flair BertEmbeddings.

    :param model_name: BERT model name or path
    :param layers: comma-separated layer indices (default: last four)
    :param pooling_op: wordpiece pooling operation
    :param scalar_mix: whether to use a learned scalar mix over layers
    :return: configured BertEmbeddings instance
    """
    return BertEmbeddings(
        model_name,
        layers=layers,
        pooling_operation=pooling_op,
        use_scalar_mix=scalar_mix,
    )
def __init__(self, MAX_WORD_N=150, MAX_SENT_N=30, MAX_WORD_SENT_N=300,
             alber_model="albert-base-v2") -> None:
    # Sets up a pooled document embedder backed by an ALBERT checkpoint
    # loaded through flair's BertEmbeddings wrapper, plus sentence limits.
    # :param MAX_WORD_N: maximum words per document
    # :param MAX_SENT_N: maximum sentences per document
    # :param MAX_WORD_SENT_N: maximum words across sentences
    # :param alber_model: ALBERT model name/path (param name typo kept for
    #     backward compatibility with existing callers)
    super().__init__()
    # NOTE(review): an ALBERT checkpoint is loaded via BertEmbeddings —
    # presumably supported by this flair version; verify.
    albert = BertEmbeddings(bert_model_or_path=alber_model)
    self.albert_embedding = DocumentPoolEmbeddings([albert])
    self.MAX_WORD_N = MAX_WORD_N
    self.MAX_SENT_N = MAX_SENT_N
    self.MAX_WORD_SENT_N = MAX_WORD_SENT_N
    self.sentence_piecer = MySentencePiecer()
def args_init(args):
    """Populate `args` with word2vec vectors, contextual-embedding stacks,
    dimensions, batch sizes and result paths; returns the mutated args."""
    # initialize word2vec
    args.word2vec = KeyedVectors.load_word2vec_format('data/mymodel-new-5-%d' % args.model_dim, binary=True)
    # initialize contextual embedding dimensions
    if args.contextual_embedding == 'word2vec':
        args.word_dim = args.tag_dim = args.dis_dim = 50
        args.stacked_embeddings = 'word2vec'
    elif args.contextual_embedding == 'elmo':  # glove + elmo
        # 868 = 100 (glove) + 768 (small ELMo) — presumably; verify against
        # the embedding configs actually loaded.
        args.word_dim = args.tag_dim = args.dis_dim = 868
        ## stacked embeddings
        # create a StackedEmbedding object that combines glove and forward/backward flair embeddings
        args.stacked_embeddings = StackedEmbeddings(
            [WordEmbeddings('glove'), ELMoEmbeddings('small')])
    elif args.contextual_embedding == 'bert':  # glove + bert
        args.word_dim = args.tag_dim = args.dis_dim = 3172
        args.stacked_embeddings = StackedEmbeddings(
            [WordEmbeddings('glove'), BertEmbeddings('bert-base-uncased')])
        # Large embeddings — shrink the batch to fit in memory.
        args.batch_size = 8
    elif args.contextual_embedding == 'flair':  # glove + flair-forward + flair-backward
        args.word_dim = args.tag_dim = args.dis_dim = 4196
        args.stacked_embeddings = StackedEmbeddings([
            WordEmbeddings('glove'),
            FlairEmbeddings('mix-forward', chars_per_chunk=128),
            FlairEmbeddings('mix-backward', chars_per_chunk=128)
        ])
        # NOTE(review): both branches assign the same batch size; the split
        # looks like a leftover from tuning per agent mode.
        if args.agent_mode == 'act':
            args.batch_size = 8
        else:
            args.batch_size = 8
    elif args.contextual_embedding == 'glove':  # not tested
        args.word_dim = args.tag_dim = args.dis_dim = 100
        args.stacked_embeddings = StackedEmbeddings([
            WordEmbeddings('glove'),
        ])
    # weights loaded, set exploration rate to minimum
    if args.load_weights:
        # 1 to 0.1. decayed to minimum.
        args.exploration_rate_start = args.exploration_rate_end
    # agent mode arguments, set number of words to 100
    if args.agent_mode == 'arg':
        args.num_words = args.context_len
        args.display_training_result = 0
    args.result_dir = 'results/%s_%s_%s' % (args.domain, args.agent_mode,
                                            args.contextual_embedding)
    return args
def collect_features(embeddings):
    """Yield a flair embedding object for each recognized name in
    `embeddings`; unrecognized names are silently skipped."""
    for name in embeddings:
        if name == "fasttext":
            yield WordEmbeddings("de")
        elif name == "bert":
            yield BertEmbeddings("bert-base-multilingual-cased", layers="-1")
        elif name == "flair-forward":
            yield FlairEmbeddings("german-forward")
        elif name == "flair-backward":
            yield FlairEmbeddings("german-backward")
def load_flair(mode='flair'):
    """Return the embedding for `mode`: a GloVe + pooled-Flair stack for
    'flair', otherwise plain BERT embeddings."""
    if mode == 'flair':
        return StackedEmbeddings([
            WordEmbeddings('glove'),
            PooledFlairEmbeddings('news-forward', pooling='min'),
            PooledFlairEmbeddings('news-backward', pooling='min')
        ])
    # bert
    ##concat last 4 layers give the best
    return BertEmbeddings('bert-base-uncased')
def get_flair_class(embed_type):
    """
    Return the correct flair class for the embed type
    """
    if embed_type == 'elmo':
        return ELMoEmbeddings()
    if embed_type == 'bert':
        return BertEmbeddings()
    # Anything else is treated as a Flair LM model name.
    return FlairEmbeddings(embed_type)