import pickle
import uuid

import fasttext
import numpy as np
from gensim.corpora import Dictionary
from gensim.models import ldamodel
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

# tokenize, clean_text, vectorize_topic_models, and dir_loc are assumed to be
# defined elsewhere in this module.


class TextEncoder:
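    """Trains and applies a document-to-vector encoder.

    The encoding strategy is selected via encoding_type (see valid_types).
    fit() trains the chosen vectorizer on raw document strings and persists it
    to save_file_loc; transform() maps documents to a 2-D numpy array with one
    encoding per document.
    """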
    valid_types = [
        'tfidf', 'count', 'binary', 'lda', 'doc2vec', 'bert_avg', 'fasttext'
    ]

    def __init__(self,
                 encoding_type,
                 engine_id,
                 max_vocab_size=10000,
                 min_n_gram=1,
                 max_n_gram=2,
                 num_of_topics=64,
                 encoding_size=64,
                 vectorizer_epochs=3,
                 max_page_size=10000,
                 fasttext_algorithm='skipgram',
                 tokenizer_level='word',
                 max_df=1.0):

        if encoding_type not in self.valid_types:
            raise ValueError(f'encoding_type must be one of {self.valid_types}, got {encoding_type!r}')

        self.engine_id = engine_id
        self.id = str(uuid.uuid4())
        self.save_file_loc = f'{dir_loc}/text_analysis_results/models/{engine_id}_{self.id}_{encoding_type}_encoder'
        self.fasttext_training_file_location = f'{dir_loc}/text_analysis_results/fasttext/{engine_id}_{self.id}_{encoding_type}'

        self.encoding_type = encoding_type
        self.max_vocab_size = max_vocab_size
        self.min_n_gram = min_n_gram
        self.max_n_gram = max_n_gram
        self.tokenizer_level = tokenizer_level
        self.max_df = max_df
        self.num_of_topics = num_of_topics
        self.encoding_size = encoding_size
        self.common_dictionary = None
        self.vectorizer_epochs = vectorizer_epochs
        self.fasttext_algorithm = fasttext_algorithm
        self.max_page_size = max_page_size

    def fit(self, documents):
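        """Fit the configured vectorizer on an iterable of document strings.

        Documents are tokenized, truncated to max_page_size tokens, and
        re-joined into strings before the selected model is trained and saved
        to save_file_loc.
        """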
        documents = [tokenize(d) for d in documents]
        documents = [d[:self.max_page_size] for d in documents]
        documents = [' '.join(d) for d in documents]

        if self.encoding_type in ['tfidf', 'count', 'binary']:
            # tfidf uses TfidfVectorizer; count and binary use CountVectorizer,
            # with binary=True producing 0/1 term-presence features.
            if self.encoding_type == 'tfidf':
                self.vectorizer = TfidfVectorizer(
                    ngram_range=(self.min_n_gram, self.max_n_gram),
                    max_features=self.max_vocab_size,
                    max_df=self.max_df,
                    analyzer=self.tokenizer_level)
            elif self.encoding_type == 'count':
                self.vectorizer = CountVectorizer(
                    ngram_range=(self.min_n_gram, self.max_n_gram),
                    max_features=self.max_vocab_size,
                    binary=False,
                    max_df=self.max_df,
                    analyzer=self.tokenizer_level)
            else:
                self.vectorizer = CountVectorizer(
                    ngram_range=(self.min_n_gram, self.max_n_gram),
                    max_features=self.max_vocab_size,
                    binary=True,
                    max_df=self.max_df,
                    analyzer=self.tokenizer_level)
            self.vectorizer.fit(documents)
            with open(self.save_file_loc, 'wb') as f:
                pickle.dump(self.vectorizer, f)
        if self.encoding_type == 'lda':
            documents_tokenized = [tokenize(i) for i in documents]
            self.common_dictionary = Dictionary(documents_tokenized)
            common_corpus = [
                self.common_dictionary.doc2bow(text)
                for text in documents_tokenized
            ]
            self.vectorizer = ldamodel.LdaModel(common_corpus,
                                                id2word=self.common_dictionary,
                                                num_topics=self.num_of_topics,
                                                passes=self.vectorizer_epochs)
            self.vectorizer.save(self.save_file_loc)
        if self.encoding_type == 'doc2vec':
            tagged_documents = [
                TaggedDocument(tokenize(doc), [i])
                for i, doc in enumerate(documents)
            ]
            self.vectorizer = Doc2Vec(tagged_documents,
                                      vector_size=self.encoding_size,
                                      window=2,
                                      min_count=1,
                                      workers=4,
                                      epochs=self.vectorizer_epochs,
                                      max_vocab_size=100000)
            self.vectorizer.delete_temporary_training_data(
                keep_doctags_vectors=True, keep_inference=True)
            self.vectorizer.save(self.save_file_loc)
        if self.encoding_type == 'fasttext':
            with open(self.fasttext_training_file_location, 'w') as f:
                for i in documents:
                    f.write(clean_text(i) + '\n')
            self.vectorizer = fasttext.train_unsupervised(
                self.fasttext_training_file_location,
                model=self.fasttext_algorithm,
                dim=self.encoding_size)
            self.vectorizer.save_model(self.save_file_loc)

    def transform(self, documents):
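        """Encode an iterable of document strings with the fitted vectorizer.

        Returns a numpy array with one row per document; empty documents are
        encoded as zero vectors for the doc2vec and fasttext encoders.
        """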

        documents = [tokenize(d) for d in documents]
        documents = [d[:self.max_page_size] for d in documents]
        documents = [' '.join(d) for d in documents]

        if self.encoding_type in ['tfidf', 'count', 'binary']:
            return self.vectorizer.transform(documents).toarray()
        if self.encoding_type == 'lda':
            documents_tokenized = [tokenize(i) for i in documents]
            other_corpus = [
                self.common_dictionary.doc2bow(i) for i in documents_tokenized
            ]
            results = []
            for i in other_corpus:
                result = self.vectorizer[i]
                result = vectorize_topic_models(result, self.num_of_topics)
                results.append(result)

            return np.array(results)
        if self.encoding_type == 'doc2vec':
            documents_tokenized = [tokenize(i) for i in documents]

            results = []
            for i in documents_tokenized:
                if i:
                    # Infer a document-level vector from the trained Doc2Vec model.
                    results.append(self.vectorizer.infer_vector(i))
                else:
                    results.append(np.zeros(self.encoding_size))

            return np.array(results)

        if self.encoding_type == 'fasttext':
            documents_clean = [clean_text(i) for i in documents]

            results = []
            for i in documents_clean:
                if i:
                    results.append(self.vectorizer.get_sentence_vector(i))
                else:
                    results.append(np.zeros(self.encoding_size))

            return np.array(results)
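

# Minimal usage sketch, assuming tokenize/clean_text/vectorize_topic_models and
# dir_loc are defined in this module and the save directories under dir_loc
# exist; the engine_id value 'demo' and the sample documents are placeholders.
if __name__ == '__main__':
    sample_docs = [
        'the quick brown fox jumps over the lazy dog',
        'text encoders turn documents into numeric vectors',
    ]
    encoder = TextEncoder(encoding_type='tfidf', engine_id='demo')
    encoder.fit(sample_docs)
    encodings = encoder.transform(sample_docs)
    print(encodings.shape)  # one row per document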