def test_document_bidirectional_lstm_embeddings():
    (sentence, glove, charlm) = init_document_embeddings()
    embeddings = DocumentRNNEmbeddings(
        [glove, charlm], hidden_size=128, bidirectional=True)
    embeddings.embed(sentence)

    assert len(sentence.get_embedding()) == 512
    assert len(sentence.get_embedding()) == embeddings.embedding_length

    sentence.clear_embeddings()
    assert len(sentence.get_embedding()) == 0
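# Hedged note (not part of the test suite above): in the Flair releases these tests target,
# DocumentRNNEmbeddings reports embedding_length == hidden_size for a unidirectional RNN and
# 4 * hidden_size for a bidirectional one, which is why hidden_size=128 yields the 512-dim
# sentence vector asserted above. The helper below just restates that rule for illustration.
def expected_document_rnn_length(hidden_size: int, bidirectional: bool) -> int:
    return hidden_size * 4 if bidirectional else hidden_size

assert expected_document_rnn_length(128, bidirectional=True) == 512
assert expected_document_rnn_length(128, bidirectional=False) == 128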
def __init__(
    self,
    *embeddings: str,
    methods: List[str] = ["rnn", "pool"],
    configs: Dict = {
        "pool_configs": {"fine_tune_mode": "linear", "pooling": "mean"},
        "rnn_configs": {
            "hidden_size": 512,
            "rnn_layers": 1,
            "reproject_words": True,
            "reproject_words_dimension": 256,
            "bidirectional": False,
            "dropout": 0.5,
            "word_dropout": 0.0,
            "locked_dropout": 0.0,
            "rnn_type": "GRU",
            "fine_tune": True,
        },
    },
):
    print("May need a couple moments to instantiate...")
    self.embedding_stack = []

    # Check methods
    for m in methods:
        assert m in self.__class__.__allowed_methods

    # Set configs for pooling and rnn parameters
    for k, v in configs.items():
        assert k in self.__class__.__allowed_configs
        setattr(self, k, v)

    # Load correct Embeddings module
    for model_name_or_path in embeddings:
        self.embedding_stack.append(_get_embedding_model(model_name_or_path))

    assert len(self.embedding_stack) != 0
    if "pool" in methods:
        self.pool_embeddings = DocumentPoolEmbeddings(
            self.embedding_stack, **self.pool_configs)
        print("Pooled embedding loaded")
    if "rnn" in methods:
        self.rnn_embeddings = DocumentRNNEmbeddings(
            self.embedding_stack, **self.rnn_configs)
        print("RNN embeddings loaded")
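# Hedged sketch of what the two branches above construct, expressed with the Flair API
# directly (the enclosing wrapper class and its _get_embedding_model helper are not shown
# in this snippet). Keyword availability can vary slightly across Flair releases.
from flair.data import Sentence
from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings, DocumentRNNEmbeddings

stack = [WordEmbeddings("glove")]
pool_embeddings = DocumentPoolEmbeddings(stack, fine_tune_mode="linear", pooling="mean")
rnn_embeddings = DocumentRNNEmbeddings(
    stack, hidden_size=512, rnn_type="GRU",
    reproject_words=True, reproject_words_dimension=256)

sentence = Sentence("Stacked word embeddings feed both document embedders.")
rnn_embeddings.embed(sentence)
print(len(sentence.get_embedding()))  # 512 for the unidirectional GRU configured above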
def test_keep_batch_order():
    # assumes a module-level `glove` word embedding, as in the variant of this test below
    embeddings = DocumentRNNEmbeddings([glove])
    sentences_1 = [Sentence("First sentence"), Sentence("This is second sentence")]
    sentences_2 = [Sentence("This is second sentence"), Sentence("First sentence")]

    embeddings.embed(sentences_1)
    embeddings.embed(sentences_2)

    assert sentences_1[0].to_original_text() == "First sentence"
    assert sentences_1[1].to_original_text() == "This is second sentence"

    # the same sentence must get the same embedding regardless of its position in the batch
    assert torch.norm(sentences_1[0].embedding - sentences_2[1].embedding) == 0.0
    assert torch.norm(sentences_1[1].embedding - sentences_2[0].embedding) == 0.0

    del embeddings
def test_keep_batch_order():
    (sentence, glove, charlm) = init_document_embeddings()
    embeddings = DocumentRNNEmbeddings([glove])
    sentences_1 = [Sentence('First sentence'), Sentence('This is second sentence')]
    sentences_2 = [Sentence('This is second sentence'), Sentence('First sentence')]

    embeddings.embed(sentences_1)
    embeddings.embed(sentences_2)

    assert sentences_1[0].to_original_text() == 'First sentence'
    assert sentences_1[1].to_original_text() == 'This is second sentence'

    # embeddings must not depend on batch order
    assert torch.norm(sentences_1[0].embedding - sentences_2[1].embedding) == 0.0
    assert torch.norm(sentences_1[1].embedding - sentences_2[0].embedding) == 0.0
def test_train_resume_text_classification_training(results_base_path, tasks_base_path):
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
    label_dict = corpus.make_label_dictionary()

    embeddings: TokenEmbeddings = FlairEmbeddings("news-forward-fast", use_cache=False)
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [embeddings], 128, 1, False)

    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)

    checkpoint = TextClassifier.load_checkpoint(results_base_path / "checkpoint.pt")
    trainer = ModelTrainer.load_from_checkpoint(checkpoint, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)

    # clean up results directory
    shutil.rmtree(results_base_path)
def train():
    corpus: Corpus = ClassificationCorpus(sst_folder,
                                          test_file='test.csv',
                                          dev_file='dev.csv',
                                          train_file='sst_dev.csv')
    label_dict = corpus.make_label_dictionary()

    stacked_embedding = WordEmbeddings('glove')

    # Stack Flair string-embeddings with optional embeddings
    word_embeddings = list(filter(None, [
        stacked_embedding,
        FlairEmbeddings('news-forward-fast'),
        FlairEmbeddings('news-backward-fast'),
    ]))

    # Initialize document embedding by passing list of word embeddings
    document_embeddings = DocumentRNNEmbeddings(
        word_embeddings,
        hidden_size=512,
        reproject_words=True,
        reproject_words_dimension=256,
    )

    # Define classifier
    classifier = TextClassifier(document_embeddings,
                                label_dictionary=label_dict,
                                multi_label=False)

    trainer = ModelTrainer(classifier, corpus)
    trainer.train(model_path, max_epochs=10, train_with_dev=False)
def train():
    # Get the SST-5 corpus
    corpus: Corpus = SENTEVAL_SST_GRANULAR()

    # create the label dictionary
    label_dict = corpus.make_label_dictionary()

    # make a list of word embeddings (using GloVe for testing)
    word_embeddings = [WordEmbeddings('glove')]

    # initialize document embedding by passing list of word embeddings
    document_embeddings = DocumentRNNEmbeddings(word_embeddings, hidden_size=256)

    # create the text classifier
    classifier = TextClassifier(document_embeddings, label_dictionary=label_dict)

    # initialize the text classifier trainer
    trainer = ModelTrainer(classifier, corpus)

    # start the training
    trainer.train('resources/taggers/trec',
                  learning_rate=0.1,
                  mini_batch_size=32,
                  anneal_factor=0.5,
                  patience=5,
                  embeddings_storage_mode='gpu',
                  max_epochs=15)
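# Hedged follow-up sketch: loading the model the trainer above writes out and predicting on
# new text. 'final-model.pt' is the filename Flair's trainer normally produces (see the
# loading tests further down); the example sentence and path are illustrative only.
from flair.data import Sentence
from flair.models import TextClassifier

classifier = TextClassifier.load('resources/taggers/trec/final-model.pt')
sentence = Sentence('A gripping, beautifully shot film.')
classifier.predict(sentence)
print(sentence.labels)  # predicted SST-5 label with its confidence score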
def optimize_lr():
    corpus, label_dictionary = load_corpus()

    embeddings = [
        WordEmbeddings('glove'),
        FlairEmbeddings('news-forward'),
        FlairEmbeddings('news-backward')
    ]
    document_embeddings = DocumentRNNEmbeddings(embeddings,
                                                hidden_size=512,
                                                reproject_words=True,
                                                reproject_words_dimension=256,
                                                bidirectional=True)

    classifier = TextClassifier(document_embeddings,
                                label_dictionary=label_dictionary,
                                multi_label=False)

    trainer = ModelTrainer(classifier, corpus)

    # 7. find learning rate
    learning_rate_tsv = trainer.find_learning_rate('resources/classifiers/',
                                                   'learning_rate.tsv')

    # 8. plot the learning rate finder curve
    from flair.visual.training_curves import Plotter
    plotter = Plotter()
    plotter.plot_learning_rate(learning_rate_tsv)
def train_model(data_dir, max_epochs):
    st.write('Creating word corpus for training...')
    corpus = ClassificationCorpus(data_dir)
    label_dict = corpus.make_label_dictionary()
    st.write('Done')

    st.write('Load and create Embeddings for text data...')
    word_embeddings = [
        WordEmbeddings('glove'),
        # FlairEmbeddings('news-forward'),
        # FlairEmbeddings('news-backward')
    ]
    document_embeddings = DocumentRNNEmbeddings(word_embeddings,
                                                hidden_size=512,
                                                reproject_words=True,
                                                reproject_words_dimension=256)
    st.write('Done')

    st.write('Preparing')
    classifier = TextClassifier(document_embeddings, label_dictionary=label_dict)
    trainer = ModelTrainer(classifier, corpus)
    trainer.train('model-saves',
                  learning_rate=0.1,
                  mini_batch_size=32,
                  anneal_factor=0.5,
                  patience=8,
                  max_epochs=max_epochs,
                  checkpoint=True)
    st.write('Model Training Finished!')
def test_train_charlm_load_use_classifier(results_base_path, tasks_base_path):
    corpus = NLPTaskDataFetcher.load_corpus("imdb", base_path=tasks_base_path)
    label_dict = corpus.make_label_dictionary()

    embedding: TokenEmbeddings = FlairEmbeddings("news-forward-fast")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [embedding], 128, 1, False, 64, False, False
    )
    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(
        results_base_path, EvaluationMetric.MACRO_F1_SCORE, max_epochs=2, test_mode=True
    )

    sentence = Sentence("Berlin is a really nice city.")
    for s in model.predict(sentence):
        for l in s.labels:
            assert l.value is not None
            assert 0.0 <= l.score <= 1.0
            assert type(l.score) is float

    loaded_model = TextClassifier.load_from_file(results_base_path / "final-model.pt")

    sentence = Sentence("I love Berlin")
    sentence_empty = Sentence(" ")

    loaded_model.predict(sentence)
    loaded_model.predict([sentence, sentence_empty])
    loaded_model.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree(results_base_path)
def test_train_load_use_classifier_with_prob(results_base_path, tasks_base_path):
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
    label_dict = corpus.make_label_dictionary()

    word_embedding: WordEmbeddings = WordEmbeddings("turian")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [word_embedding], 128, 1, False, 64, False, False)
    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False)

    sentence = Sentence("Berlin is a really nice city.")
    for s in model.predict(sentence, multi_class_prob=True):
        for l in s.labels:
            assert l.value is not None
            assert 0.0 <= l.score <= 1.0
            assert type(l.score) is float

    loaded_model = TextClassifier.load(results_base_path / "final-model.pt")

    sentence = Sentence("I love Berlin")
    sentence_empty = Sentence(" ")

    loaded_model.predict(sentence, multi_class_prob=True)
    loaded_model.predict([sentence, sentence_empty], multi_class_prob=True)
    loaded_model.predict([sentence_empty], multi_class_prob=True)

    # clean up results directory
    shutil.rmtree(results_base_path)
def test_train_resume_text_classification_training(results_base_path, tasks_base_path):
    corpus = NLPTaskDataFetcher.load_corpus('imdb', base_path=tasks_base_path)
    label_dict = corpus.make_label_dictionary()

    embeddings: TokenEmbeddings = FlairEmbeddings('news-forward-fast', use_cache=False)
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [embeddings], 128, 1, False)

    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, test_mode=True, checkpoint=True)

    trainer = ModelTrainer.load_from_checkpoint(
        results_base_path / 'checkpoint.pt', 'TextClassifier', corpus)
    trainer.train(results_base_path, max_epochs=2, test_mode=True, checkpoint=True)

    # clean up results directory
    shutil.rmtree(results_base_path)
def test_train_charlm_nocache_load_use_classifier(results_base_path, tasks_base_path):
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
    label_dict = corpus.make_label_dictionary()

    embedding: TokenEmbeddings = FlairEmbeddings("news-forward-fast", use_cache=False)
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [embedding], 128, 1, False, 64, False, False)
    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False)

    sentence = Sentence("Berlin is a really nice city.")
    for s in model.predict(sentence):
        for l in s.labels:
            assert l.value is not None
            assert 0.0 <= l.score <= 1.0
            assert type(l.score) is float

    loaded_model = TextClassifier.load(results_base_path / "final-model.pt")

    sentence = Sentence("I love Berlin")
    sentence_empty = Sentence(" ")

    loaded_model.predict(sentence)
    loaded_model.predict([sentence, sentence_empty])
    loaded_model.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree(results_base_path)
def test_train_load_use_classifier_multi_label(results_base_path, tasks_base_path):
    # corpus = NLPTaskDataFetcher.load_corpus('multi_class', base_path=tasks_base_path)
    corpus = NLPTaskDataFetcher.load_classification_corpus(
        data_folder=tasks_base_path / "multi_class"
    )
    label_dict = corpus.make_label_dictionary()

    word_embedding: WordEmbeddings = WordEmbeddings("turian")
    document_embeddings = DocumentRNNEmbeddings(
        embeddings=[word_embedding],
        hidden_size=32,
        reproject_words=False,
        bidirectional=False,
    )
    model = TextClassifier(document_embeddings, label_dict, multi_label=True)

    trainer = ModelTrainer(model, corpus)
    trainer.train(
        results_base_path,
        EvaluationMetric.MICRO_F1_SCORE,
        mini_batch_size=1,
        max_epochs=100,
        test_mode=True,
        checkpoint=False,
    )

    sentence = Sentence("apple tv")
    for s in model.predict(sentence):
        for l in s.labels:
            print(l)
            assert l.value is not None
            assert 0.0 <= l.score <= 1.0
            assert type(l.score) is float

    sentence = Sentence("apple tv")
    for s in model.predict(sentence):
        assert "apple" in sentence.get_label_names()
        assert "tv" in sentence.get_label_names()
        for l in s.labels:
            print(l)
            assert l.value is not None
            assert 0.0 <= l.score <= 1.0
            assert type(l.score) is float

    loaded_model = TextClassifier.load_from_file(results_base_path / "final-model.pt")

    sentence = Sentence("I love Berlin")
    sentence_empty = Sentence(" ")

    loaded_model.predict(sentence)
    loaded_model.predict([sentence, sentence_empty])
    loaded_model.predict([sentence_empty])

    # clean up results directory
    shutil.rmtree(results_base_path)
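# Hedged sketch of the FastText-style layout Flair's classification corpus loaders above
# expect: one example per line, one or more "__label__<name>" prefixes, then the text.
# The directory, file names and labels below are illustrative, not taken from the fixtures.
from pathlib import Path
from flair.datasets import ClassificationCorpus

data_dir = Path("demo_multi_class")
data_dir.mkdir(exist_ok=True)
lines = [
    "__label__apple __label__tv apple tv remote is not pairing",
    "__label__tv the screen flickers after the latest update",
]
for split in ("train.txt", "dev.txt", "test.txt"):
    (data_dir / split).write_text("\n".join(lines) + "\n")

corpus = ClassificationCorpus(data_dir, train_file="train.txt",
                              dev_file="dev.txt", test_file="test.txt")
print(corpus.make_label_dictionary())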
def train(args):
    """Train."""
    start_time = time.time()

    column_format = {i: col for i, col in enumerate(args.data_columns)}
    corpus: Corpus = ClassColumnCorpus(
        args.data_dir,
        column_format,
        train_file=args.train_file,
        dev_file=args.dev_file,
        comment_symbol=args.comment_symbol,
        label_symbol=args.label_symbol,
    )

    tag_type = args.data_columns[-1]
    tag_dict = corpus.make_tag_dictionary(tag_type=tag_type)
    label_dict = corpus.make_label_dictionary()
    vocab = corpus.make_vocab_dictionary().get_items()
    embeddings = utils.init_embeddings(vocab, args)

    model1: SequenceTagger = SequenceTagger(
        hidden_size=args.hidden_size,
        embeddings=embeddings,
        tag_dictionary=tag_dict,
        tag_type=tag_type,
        column_format=column_format,
        use_crf=True,
        use_attn=args.use_attn,
        attn_type=args.attn_type,
        num_heads=args.num_heads,
        scaling=args.scaling,
        pooling_operation=args.pooling_operation,
        use_sent_query=args.use_sent_query,
    )

    document_embeddings = DocumentRNNEmbeddings(
        [embeddings],
        hidden_size=args.hidden_size,
    )
    model2 = TextClassifier(document_embeddings, label_dictionary=label_dict)

    utils.init_joint_models(model1, model2, args)

    trainer: JointModelTrainer = JointModelTrainer(
        model1, model2, corpus, utils.optim_method(args.optim)
    )
    trainer.train(
        args.model_dir,
        mini_batch_size=args.mini_batch_size,
        max_epochs=args.max_epochs,
        anneal_factor=args.anneal_factor,
        learning_rate=args.learning_rate,
        patience=args.patience,
        min_learning_rate=args.min_learning_rate,
        embeddings_storage_mode=args.embeddings_storage_mode,
        gamma=args.gamma,
    )

    logger.info("End of training: time %.1f min", (time.time() - start_time) / 60)
def init(tasks_base_path) -> Tuple[Corpus, Dictionary, TextClassifier]:
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / 'ag_news')
    label_dict = corpus.make_label_dictionary()

    glove_embedding = WordEmbeddings('turian')
    document_embeddings = DocumentRNNEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextClassifier(document_embeddings, label_dict, False)
    return corpus, label_dict, model
def init(tasks_base_path) -> Tuple[Corpus, TextRegressor, ModelTrainer]:
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.REGRESSION, tasks_base_path)

    glove_embedding = WordEmbeddings('glove')
    document_embeddings = DocumentRNNEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextRegressor(document_embeddings)
    trainer = ModelTrainer(model, corpus)
    return corpus, model, trainer
def set_up_document_RNNEmbedding(self, hidden_size=512, reproject_words=True,
                                 reproject_words_dimension=256):
    self.document_RNNEmbeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        self.word_embeddings,
        hidden_size=hidden_size,
        reproject_words=reproject_words,
        reproject_words_dimension=reproject_words_dimension,
    )
def init(tasks_base_path) -> Tuple[TaggedCorpus, Dictionary, TextClassifier]:
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.AG_NEWS, tasks_base_path)
    label_dict = corpus.make_label_dictionary()

    glove_embedding: WordEmbeddings = WordEmbeddings('turian')
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextClassifier(document_embeddings, label_dict, False)
    return corpus, label_dict, model
def train(self):
    corpus = NLPTaskDataFetcher.load_classification_corpus(
        Path(self.corpus_path),
        test_file="test_clean_text.txt",
        dev_file="dev_clean_text.txt",
        train_file="train_clean_text.txt")

    embeddings = [WordEmbeddings(self.word_emb_path),
                  FlairEmbeddings('polish-forward'),
                  FlairEmbeddings('polish-backward')]
    document_embeddings = DocumentRNNEmbeddings(embeddings,
                                                hidden_size=self.hidden_size,
                                                bidirectional=True)

    classifier = TextClassifier(document_embeddings,
                                label_dictionary=corpus.make_label_dictionary(),
                                multi_label=False)

    trainer = ModelTrainer(classifier, corpus)
    trainer.train(self.model_path,
                  evaluation_metric=EvaluationMetric.MACRO_F1_SCORE,
                  max_epochs=self.epochs)
def init(tasks_base_path) -> Tuple[Corpus, Dictionary, TextClassifier]:
    # get training, test and dev data
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "ag_news")
    label_dict = corpus.make_label_dictionary()

    glove_embedding: WordEmbeddings = WordEmbeddings("turian")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextClassifier(document_embeddings, label_dict, multi_label=False)
    return corpus, label_dict, model
def train(args):
    """Train."""
    start_time = time.time()

    if args.one_per_line:
        corpus: Corpus = ClassificationCorpus(
            args.data_dir,
            train_file=args.train_file,
            dev_file=args.dev_file,
        )
    else:
        assert args.label_symbol is not None
        corpus: Corpus = FlyClassificationCorpus(
            args.data_dir,
            train_file=args.train_file,
            dev_file=args.dev_file,
            comment_symbol=args.comment_symbol,
            label_symbol=args.label_symbol,
        )

    label_dict = corpus.make_label_dictionary()
    vocab = corpus.make_vocab_dictionary().get_items()
    embeddings = utils.init_embeddings(vocab, args)

    document_embeddings = DocumentRNNEmbeddings(
        [embeddings],
        hidden_size=args.hidden_size,
        use_attn=args.use_attn,
        num_heads=args.num_heads,
        scaling=args.scaling,
        pooling_operation=args.pooling_operation,
        use_sent_query=args.use_sent_query,
    )
    model = TextClassifier(document_embeddings, label_dictionary=label_dict)

    utils.init_model(model, args)

    trainer: ModelTrainer = ModelTrainer(model, corpus, utils.optim_method(args.optim))
    trainer.train(
        args.model_dir,
        mini_batch_size=args.mini_batch_size,
        max_epochs=args.max_epochs,
        anneal_factor=args.anneal_factor,
        learning_rate=args.learning_rate,
        patience=args.patience,
        min_learning_rate=args.min_learning_rate,
        embeddings_storage_mode=args.embeddings_storage_mode,
    )

    logger.info("End of training: time %.1f min", (time.time() - start_time) / 60)
def init(tasks_base_path) -> Tuple[TaggedCorpus, TextRegressor, RegressorTrainer]:
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.REGRESSION, tasks_base_path)

    glove_embedding: WordEmbeddings = WordEmbeddings("glove")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextRegressor(document_embeddings, Dictionary(), False)
    trainer = RegressorTrainer(model, corpus)
    return corpus, model, trainer
def init(tasks_base_path) -> Tuple[Corpus, TextRegressor, ModelTrainer]:
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "regression")

    glove_embedding: WordEmbeddings = WordEmbeddings("glove")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextRegressor(document_embeddings)
    trainer = ModelTrainer(model, corpus)
    return corpus, model, trainer
def test_document_lstm_embeddings():
    sentence, glove, charlm = init_document_embeddings()

    embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove, charlm], hidden_size=128, bidirectional=False)
    embeddings.embed(sentence)

    assert len(sentence.get_embedding()) == 128
    assert len(sentence.get_embedding()) == embeddings.embedding_length

    sentence.clear_embeddings()
    assert len(sentence.get_embedding()) == 0
def _set_up_model(self, params: dict):
    embedding_params = {
        key: params[key]
        for key in params if key in DOCUMENT_EMBEDDING_PARAMETERS
    }

    if self.document_embedding_type == 'lstm':
        document_embedding = DocumentRNNEmbeddings(**embedding_params)
    else:
        document_embedding = DocumentPoolEmbeddings(**embedding_params)

    text_classifier = TextClassifier(
        label_dictionary=self.label_dictionary,
        multi_label=self.multi_label,
        document_embeddings=document_embedding)
    return text_classifier
def classify(data, labels, test, train, validation):
    # split the raw data into train / test / validation partitions
    train_data = [k for k in data.keys() if k in train]
    train_labels = [labels[k] for k in train_data]
    train_data = [data[k] for k in train_data]

    test_data = [k for k in data.keys() if k in test]
    test_labels = [labels[k] for k in test_data]
    test_data = [data[k] for k in test_data]

    validation_data = [k for k in data.keys() if k in validation]
    validation_labels = [labels[k] for k in validation_data]
    validation_data = [data[k] for k in validation_data]

    save_training_files(train_data, train_labels, test_data, test_labels,
                        validation_data, validation_labels)

    corpus = NLPTaskDataFetcher.load_classification_corpus(
        Path('./'), test_file='test.txt', dev_file='dev.txt', train_file='train.txt')

    word_embeddings = [
        WordEmbeddings('pl'),
        FlairEmbeddings('polish-forward'),
        FlairEmbeddings('polish-backward')
    ]
    doc_embeddings = DocumentRNNEmbeddings(word_embeddings,
                                           hidden_size=512,
                                           reproject_words=True,
                                           reproject_words_dimension=256)

    classifier = TextClassifier(
        doc_embeddings,
        label_dictionary=corpus.make_label_dictionary(),
        multi_label=False)

    trainer = ModelTrainer(classifier, corpus)
    trainer.train('./', max_epochs=25)

    classifier = TextClassifier.load_from_file('./best-model.pt')
    validation_data = [Sentence(x) for x in validation_data]
    for x in validation_data:
        classifier.predict(x)
    predicted = [int(x.labels[0].value) for x in validation_data]

    remove_training_files()

    precision, recall, f1, _ = precision_recall_fscore_support(
        validation_labels, predicted, average='binary')
    # note: the 'accuracy' key actually carries the precision value
    return {
        'accuracy': float("{:.3f}".format(round(precision, 3))),
        'recall': float("{:.3f}".format(round(recall, 3))),
        'f1': float("{:.3f}".format(round(f1, 3)))
    }
def test_document_bidirectional_lstm_embeddings():
    # `glove` and `flair_embedding` are expected as module-level token embeddings
    sentence: Sentence = Sentence(
        "I love Berlin. Berlin is a great place to live.")

    embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove, flair_embedding], hidden_size=128, bidirectional=True)
    embeddings.embed(sentence)

    assert len(sentence.get_embedding()) == 512
    assert len(sentence.get_embedding()) == embeddings.embedding_length

    sentence.clear_embeddings()
    assert len(sentence.get_embedding()) == 0

    del embeddings
# nested inside a factory function (not shown in this excerpt) that provides `model_type`,
# `tag_type`, `multi_label`, `document_embedding_type` and a prepared `param_selector`
def _set_up_model(self, params: dict):
    if self.model_type == SequenceTagger:
        parameter_set = SEQUENCE_TAGGER_PARAMETERS
    elif self.model_type == TextClassifier:
        parameter_set = DOCUMENT_EMBEDDING_PARAMETERS

    model_params = {key: params[key] for key in params if key in parameter_set}

    if self.model_type == SequenceTagger:
        self.tag_type = tag_type
        self.tag_dictionary = self.corpus.make_tag_dictionary(self.tag_type)
        model: SequenceTagger = SequenceTagger(
            tag_dictionary=self.tag_dictionary,
            tag_type=self.tag_type,
            **model_params,
        )
    elif self.model_type == TextClassifier:
        self.multi_label = multi_label
        self.document_embedding_type = document_embedding_type
        self.label_dictionary = self.corpus.make_label_dictionary()
        if self.document_embedding_type == "lstm":
            document_embedding = DocumentRNNEmbeddings(**model_params)
        else:
            document_embedding = DocumentPoolEmbeddings(**model_params)
        model: TextClassifier = TextClassifier(
            label_dictionary=self.label_dictionary,
            multi_label=self.multi_label,
            document_embeddings=document_embedding,
        )
    else:
        log.error("Unknown class type for parameter selection")
        raise TypeError

# We bind _set_up_model method to the appropriate ParamSelector class
# specified by model_type variable (still inside the enclosing factory)
param_selector._set_up_model = types.MethodType(_set_up_model, param_selector)
return param_selector
def _set_up_model(self, params: dict, label_dictionary):
    document_embedding = params['document_embeddings'].__name__

    if document_embedding == "DocumentRNNEmbeddings":
        embedding_params = {
            key: params[key]
            for key, value in params.items()
            if key in DOCUMENT_RNN_EMBEDDING_PARAMETERS
        }
        embedding_params['embeddings'] = [
            WordEmbeddings(TokenEmbedding)
            if type(params['embeddings']) == list
            else WordEmbeddings(params['embeddings'])
            for TokenEmbedding in params['embeddings']
        ]
        document_embedding = DocumentRNNEmbeddings(**embedding_params)

    elif document_embedding == "DocumentPoolEmbeddings":
        embedding_params = {
            key: params[key]
            for key, value in params.items()
            if key in DOCUMENT_POOL_EMBEDDING_PARAMETERS
        }
        embedding_params['embeddings'] = [
            WordEmbeddings(TokenEmbedding)
            for TokenEmbedding in params['embeddings']
        ]
        document_embedding = DocumentPoolEmbeddings(**embedding_params)

    elif document_embedding == "TransformerDocumentEmbeddings":
        embedding_params = {
            key: params[key]
            for key, value in params.items()
            if key in DOCUMENT_TRANSFORMER_EMBEDDING_PARAMETERS
        }
        document_embedding = TransformerDocumentEmbeddings(**embedding_params)

    else:
        raise Exception("Please provide a flair document embedding class")

    text_classifier: TextClassifier = TextClassifier(
        label_dictionary=label_dictionary,
        multi_label=self.multi_label,
        document_embeddings=document_embedding,
    )
    return text_classifier
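# Hedged sketch of how _set_up_model hooks like the ones above are usually driven: Flair's
# hyperparameter search builds a SearchSpace and hands it to a param selector, which calls
# _set_up_model once per sampled configuration. Class and parameter names follow the
# tutorial-era flair.hyperparameter API and may differ between releases; the TREC_6 corpus,
# paths and option values below are illustrative only.
from hyperopt import hp
from flair.datasets import TREC_6
from flair.embeddings import WordEmbeddings
from flair.hyperparameter.param_selection import (
    SearchSpace, Parameter, TextClassifierParamSelector, OptimizationValue)

corpus = TREC_6()

search_space = SearchSpace()
search_space.add(Parameter.EMBEDDINGS, hp.choice, options=[[WordEmbeddings('glove')]])
search_space.add(Parameter.HIDDEN_SIZE, hp.choice, options=[32, 64, 128])
search_space.add(Parameter.LEARNING_RATE, hp.choice, options=[0.05, 0.1, 0.15])
search_space.add(Parameter.MINI_BATCH_SIZE, hp.choice, options=[16, 32])

param_selector = TextClassifierParamSelector(
    corpus,
    False,                        # multi_label
    'resources/param_selection',  # base path for results
    'lstm',                       # document_embedding_type -> DocumentRNNEmbeddings
    max_epochs=10,
    training_runs=1,
    optimization_value=OptimizationValue.DEV_SCORE,
)
param_selector.optimize(search_space, max_evals=10)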