示例#1
0
 def get_vocabulary(self, dump_path: str,
                    markup_path: str) -> StressVocabulary:
     """
     Return the stress vocabulary cached on this object, creating it on first use.

     On the first call a new StressVocabulary is stored in ``self.vocabulary``
     and then filled: from ``dump_path`` if that file exists, otherwise by
     parsing ``markup_path`` when it is not None. If neither applies the
     vocabulary is left as constructed.

     :param dump_path: path of a saved vocabulary dump.
     :param markup_path: path of markups to parse when there is no dump; may be None.
     :return: the cached vocabulary.
     """
     if self.vocabulary is not None:
         return self.vocabulary

     # Stored on self before filling, so the attribute is set even if filling raises.
     self.vocabulary = StressVocabulary()
     dump_exists = os.path.isfile(dump_path)
     if dump_exists:
         self.vocabulary.load(dump_path)
     elif markup_path is not None:
         self.vocabulary.parse(markup_path)
     return self.vocabulary
示例#2
0
 def test_vocabulary(self):
     """
     Check that constructing a StressVocabulary with a dump path and the
     example markup creates the dump file, and that the word at index 0
     can be fetched afterwards.
     """
     dump_file = os.path.join(EXAMPLES_DIR, "temp.pickle")
     vocabulary = StressVocabulary(dump_file, MARKUP_XML_EXAMPLE)
     self.assertTrue(os.path.exists(dump_file))
     os.remove(dump_file)
     # Keep the try body to the single call that may raise, and report the
     # failure with an explicit message instead of assertTrue(False).
     try:
         first_word = vocabulary.get_word(0)
     except IndexError:
         self.fail("vocabulary.get_word(0) raised IndexError: vocabulary is empty")
     self.assertIsNotNone(first_word)
示例#3
0
 def get_words(self, vocabulary: StressVocabulary) -> List[str]:
     """
     Return the lower-cased texts of the words on the current path.

     :param vocabulary: vocabulary used to resolve each index in ``self.indices``.
     :return: lower-cased word texts, in the order of ``self.indices``.
     """
     # Lazy lookup: each word is fetched right before its text is lower-cased.
     looked_up = (vocabulary.get_word(position) for position in self.indices)
     return [entry.text.lower() for entry in looked_up]
示例#4
0
 def test_markov(self):
     """
     For n-gram orders 2..4, build a vocabulary and a Markov model from the
     example markup and check that both dump files are created, that the
     model has ``vocabulary.size() - n + 1`` transition entries, and that
     the transition counts sum to the same number.
     """
     # The dump paths do not depend on n, so compute them once.
     vocab_dump_file = os.path.join(EXAMPLES_DIR, "vocab.pickle")
     markov_dump_file = os.path.join(EXAMPLES_DIR, "markov.pickle")
     for n in range(2, 5):
         try:
             vocabulary = StressVocabulary(vocab_dump_file, MARKUP_XML_EXAMPLE)
             markov = MarkovModelContainer(markov_dump_file,
                                           vocabulary,
                                           MARKUP_XML_EXAMPLE,
                                           n_grams=n)
             self.assertTrue(os.path.exists(vocab_dump_file))
             self.assertTrue(os.path.exists(markov_dump_file))
         finally:
             # Always remove whatever was written, so a failure here does not
             # leave a stale dump behind for the next iteration or test run.
             for dump_file in (vocab_dump_file, markov_dump_file):
                 if os.path.exists(dump_file):
                     os.remove(dump_file)
         expected = vocabulary.size() - n + 1
         self.assertEqual(expected, len(markov.transitions))
         total_count = sum(
             sum(transition.values())
             for transition in markov.transitions.values())
         self.assertEqual(total_count, expected)
示例#5
0
    def get_generator(self, model_path: str, token_vocab_path: str,
                      stress_vocab_dump_path: str) -> Generator:
        """
        Return the generator cached on this object, building it on first use.

        On the first call the token vocabulary is read from
        ``token_vocab_path``; the stress vocabulary is loaded from
        ``stress_vocab_dump_path`` if that file exists, otherwise it is
        produced by ``inflate_stress_vocabulary`` and saved to that path.
        The language model is loaded with a transform built from the
        unknown-token and end-of-sequence indices.

        :param model_path: directory of the language model (asserted to be a directory).
        :param token_vocab_path: directory of the token vocabulary (asserted to be a directory).
        :param stress_vocab_dump_path: file with the stress vocabulary dump.
        :return: the cached generator.
        """
        if self.generator is not None:
            return self.generator

        assert os.path.isdir(model_path) and os.path.isdir(token_vocab_path)
        token_vocabulary = Vocabulary.from_files(token_vocab_path)
        stress_vocabulary = StressVocabulary()
        if os.path.isfile(stress_vocab_dump_path):
            stress_vocabulary.load(stress_vocab_dump_path)
        else:
            stress_vocabulary = inflate_stress_vocabulary(
                token_vocabulary, self.get_stress_predictor())
            stress_vocabulary.save(stress_vocab_dump_path)

        eos_index = token_vocabulary.get_token_index(END_SYMBOL)
        unk_index = token_vocabulary.get_token_index(DEFAULT_OOV_TOKEN)
        transforms = [ExcludeTransform((unk_index, eos_index))]
        language_model = LanguageModel.load(model_path,
                                            vocabulary_dir=token_vocab_path,
                                            transforms=transforms)

        self.generator = Generator(language_model, token_vocabulary,
                                   stress_vocabulary, eos_index)
        return self.generator
示例#6
0
文件: api.py 项目: che1974/rupo
 def get_lstm_generator(self, model_path: str,
                        word_form_vocab_dump_path: str,
                        stress_vocab_dump_path: str,
                        gram_dump_path: str) -> Generator:
     """
     Return the LSTM-based generator cached on this object, building it on first use.

     :param model_path: path passed to LSTMModelContainer as the model location.
     :param word_form_vocab_dump_path: path passed to both LSTMModelContainer
         and WordFormVocabulary.
     :param stress_vocab_dump_path: path passed to StressVocabulary.
     :param gram_dump_path: path passed to LSTMModelContainer.
     :return: the cached generator.
     """
     if self.lstm_generator is not None:
         return self.lstm_generator

     model_container = LSTMModelContainer(model_path,
                                          word_form_vocab_dump_path,
                                          gram_dump_path)
     word_forms = WordFormVocabulary(word_form_vocab_dump_path)
     stress_vocabulary = StressVocabulary(stress_vocab_dump_path)
     self.lstm_generator = Generator(model_container, stress_vocabulary,
                                     word_forms)
     return self.lstm_generator
示例#7
0
    def filter_model(self, model: np.array,
                     vocabulary: StressVocabulary) -> np.array:
        """
        Zero out the entries of a language model whose words do not pass the filter.

        The model is modified in place: for every index ``i`` for which
        ``self.filter_word(vocabulary.get_word(i))`` is falsy, ``model[i]``
        is set to 0.0. No normalisation is performed in this method.

        :param model: the initial model, indexed by word index.
        :param vocabulary: vocabulary used to resolve word indices.
        :return: the same ``model`` object after filtering.
        """
        for word_index in range(len(model)):
            candidate = vocabulary.get_word(word_index)
            if self.filter_word(candidate):
                continue
            model[word_index] = 0.0
        return model
示例#8
0
 def inflate_vocab(self, dump_path, top_n=None) -> None:
     """
     Build a stress vocabulary from this vocabulary's word forms.

     Each word form's text is run through a CombinedStressPredictor, every
     predicted position becomes a primary stress, and the resulting word is
     added to the stress vocabulary under the word form's index. The
     vocabulary is saved at the end via ``save()`` with no arguments.

     :param top_n: how many leading entries to take; None means all of them.
     :param dump_path: path passed to the StressVocabulary constructor
         (presumably where the vocabulary is saved - confirm against StressVocabulary).
     """
     from rupo.main.vocabulary import StressVocabulary
     from rupo.stress.word import StressedWord, Stress
     from rupo.stress.predictor import CombinedStressPredictor
     stress_vocabulary = StressVocabulary(dump_path)
     predictor = CombinedStressPredictor()
     selected_forms = self.word_forms if top_n is None else self.word_forms[:top_n]
     for position, form in tqdm(enumerate(selected_forms), desc="Accenting words"):
         form_text = form.text
         primary_stresses = {
             Stress(stress_position, Stress.Type.PRIMARY)
             for stress_position in predictor.predict(form_text)
         }
         stress_vocabulary.add_word(StressedWord(form_text, primary_stresses), position)
     stress_vocabulary.save()
示例#9
0
文件: api.py 项目: che1974/rupo
 def get_vocabulary(self, dump_path: str,
                    markup_path: str) -> StressVocabulary:
     """
     Return the stress vocabulary cached on this object, creating it on first use.

     :param dump_path: first argument passed to the StressVocabulary constructor.
     :param markup_path: second argument passed to the StressVocabulary constructor.
     :return: the cached vocabulary.
     """
     if self.vocabulary is not None:
         return self.vocabulary
     self.vocabulary = StressVocabulary(dump_path, markup_path)
     return self.vocabulary
示例#10
0
class Engine:
    """
    Entry point that ties together stress prediction, markup, metre
    classification, rhyme lookup and poem generation.

    Heavy components (vocabulary, generator, G2P models, stress predictors)
    are created on first request and cached on the instance.
    """

    def __init__(self, language="ru"):
        """
        :param language: language used by ``load`` when creating the stress predictor.
        """
        self.language = language  # type: str
        self.vocabulary = None  # type: StressVocabulary
        self.generator = None  # type: Generator
        self.g2p_models = dict()  # type: Dict[str, RNNG2PModel]
        self.stress_predictors = dict()  # type: Dict[str, StressPredictor]

    def load(self,
             stress_model_path: str,
             zalyzniak_dict: str,
             raw_stress_dict_path=None,
             stress_trie_path=None):
        """
        Reset the cached models and create the stress predictor for ``self.language``.

        Also makes sure the DATA_DIR and DICT_DIR directories exist.

        :param stress_model_path: path to the stress model.
        :param zalyzniak_dict: path to the Zalyzniak dictionary.
        :param raw_stress_dict_path: path to the raw stress dictionary, if any.
        :param stress_trie_path: path to the stress trie, if any.
        """
        self.g2p_models = dict()
        self.stress_predictors = dict()
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(DATA_DIR, exist_ok=True)
        os.makedirs(DICT_DIR, exist_ok=True)
        self.get_stress_predictor(self.language, stress_model_path,
                                  raw_stress_dict_path, stress_trie_path,
                                  zalyzniak_dict)

    def get_vocabulary(self, dump_path: str,
                       markup_path: str) -> StressVocabulary:
        """
        Return the cached stress vocabulary, creating it on first use.

        On the first call the vocabulary is loaded from ``dump_path`` if that
        file exists, otherwise parsed from ``markup_path`` when it is not None.

        :param dump_path: path of a saved vocabulary dump.
        :param markup_path: path of markups to parse when there is no dump; may be None.
        :return: the cached vocabulary.
        """
        if self.vocabulary is None:
            self.vocabulary = StressVocabulary()
            if os.path.isfile(dump_path):
                self.vocabulary.load(dump_path)
            elif markup_path is not None:
                self.vocabulary.parse(markup_path)
        return self.vocabulary

    def get_generator(self, model_path: str, token_vocab_path: str,
                      stress_vocab_dump_path: str) -> Generator:
        """
        Return the cached generator, building it on first use.

        :param model_path: directory of the language model.
        :param token_vocab_path: directory of the token vocabulary.
        :param stress_vocab_dump_path: stress vocabulary dump; created from the
            token vocabulary and saved there if the file does not exist.
        :return: the cached generator.
        """
        if self.generator is None:
            assert os.path.isdir(model_path) and os.path.isdir(
                token_vocab_path)
            vocabulary = Vocabulary.from_files(token_vocab_path)
            if os.path.isfile(stress_vocab_dump_path):
                stress_vocabulary = StressVocabulary()
                stress_vocabulary.load(stress_vocab_dump_path)
            else:
                stress_vocabulary = inflate_stress_vocabulary(
                    vocabulary, self.get_stress_predictor())
                stress_vocabulary.save(stress_vocab_dump_path)

            eos_index = vocabulary.get_token_index(END_SYMBOL)
            unk_index = vocabulary.get_token_index(DEFAULT_OOV_TOKEN)
            exclude_transform = ExcludeTransform((unk_index, eos_index))

            model = LanguageModel.load(model_path,
                                       vocabulary_dir=token_vocab_path,
                                       transforms=[
                                           exclude_transform,
                                       ])
            self.generator = Generator(model, vocabulary, stress_vocabulary,
                                       eos_index)
        return self.generator

    def get_stress_predictor(self,
                             language="ru",
                             stress_model_path: str = None,
                             raw_stress_dict_path=None,
                             stress_trie_path=None,
                             zalyzniak_dict=ZALYZNYAK_DICT,
                             cmu_dict=CMU_DICT):
        """
        Return the stress predictor cached for ``language``, creating it on first use.

        The path arguments are only used when the predictor is first created
        for that language; later calls return the cached instance.

        :param language: language of the predictor.
        :param stress_model_path: path to the stress model.
        :param raw_stress_dict_path: path to the raw stress dictionary.
        :param stress_trie_path: path to the stress trie.
        :param zalyzniak_dict: path to the Zalyzniak dictionary.
        :param cmu_dict: path to the CMU dictionary.
        :return: the cached predictor.
        """
        if self.stress_predictors.get(language) is None:
            self.stress_predictors[language] = CombinedStressPredictor(
                language, stress_model_path, raw_stress_dict_path,
                stress_trie_path, zalyzniak_dict, cmu_dict)
        return self.stress_predictors[language]

    def get_g2p_model(self, language="ru", model_path=None):
        """
        Return the G2P model cached for ``language``, loading it on first use.

        :param language: language of the model.
        :param model_path: path to the model; when None the default model for
            "ru" or "en" is used.
        :return: the loaded model, or None when no ``model_path`` is given and
            the language has no default model.
        """
        if self.g2p_models.get(language) is None:
            if model_path is None:
                if language == "ru":
                    model_path = RU_G2P_DEFAULT_MODEL
                elif language == "en":
                    model_path = EN_G2P_DEFAULT_MODEL
                else:
                    return None
            # Cache only after a successful load, so a failed or skipped load
            # never leaves an unloaded model behind for later calls.
            model = RNNG2PModel(language=language)
            model.load(model_path)
            self.g2p_models[language] = model
        return self.g2p_models[language]

    def get_stresses(self, word: str, language: str = "ru") -> List[int]:
        """
        :param word: the word.
        :param language: the language.
        :return: stresses of the word, as returned by the stress predictor.
        """
        return self.get_stress_predictor(language).predict(word)

    @staticmethod
    def get_word_syllables(word: str) -> List[str]:
        """
        :param word: the word.
        :return: texts of its syllables.
        """
        return [syllable.text for syllable in Graphemes.get_syllables(word)]

    @staticmethod
    def count_syllables(word: str) -> int:
        """
        :param word: the word.
        :return: number of syllables in it.
        """
        return len(Graphemes.get_syllables(word))

    def get_markup(self, text: str, language: str = "ru") -> Markup:
        """
        :param text: the text.
        :param language: the language.
        :return: markup of the text built with the stress predictor.
        """
        return Markup.process_text(text, self.get_stress_predictor(language))

    def get_improved_markup(
            self,
            text: str,
            language: str = "ru") -> Tuple[Markup, ClassificationResult]:
        """
        :param text: the text.
        :param language: the language.
        :return: result of ``MetreClassifier.improve_markup`` applied to the
            text's markup.
        """
        markup = Markup.process_text(text, self.get_stress_predictor(language))
        return MetreClassifier.improve_markup(markup)

    def classify_metre(self, text: str, language: str = "ru") -> str:
        """
        :param text: the text.
        :param language: the language.
        :return: metre of the text.
        """
        return MetreClassifier.classify_metre(
            Markup.process_text(text,
                                self.get_stress_predictor(language))).metre

    def generate_markups(self, input_path: str, input_type: FileType,
                         output_path: str, output_type: FileType) -> None:
        """
        Generate markups for texts and write them to a file.

        :param input_path: path to the folder/file with texts.
        :param input_type: type of the input files.
        :param output_path: path to the file for the resulting markups.
        :param output_type: type of the resulting file.
        """
        markups = Reader.read_markups(input_path, input_type, False,
                                      self.get_stress_predictor())
        writer = Writer(output_type, output_path)
        writer.open()
        # Close the writer even if reading or writing a markup fails.
        try:
            for markup in markups:
                writer.write_markup(markup)
        finally:
            writer.close()

    def is_rhyme(self, word1: str, word2: str) -> bool:
        """
        :param word1: the first word.
        :param word2: the second word.
        :return: whether the words rhyme.
        """
        markup_word1 = self.get_markup(word1).lines[0].words[0]
        markup_word1.set_stresses(self.get_stresses(word1))
        markup_word2 = self.get_markup(word2).lines[0].words[0]
        markup_word2.set_stresses(self.get_stresses(word2))
        return Rhymes.is_rhyme(markup_word1, markup_word2)

    def generate_poem(self,
                      model_path: str,
                      token_vocab_path: str = None,
                      stress_vocab_path: str = None,
                      metre_schema: str = "-+",
                      rhyme_pattern: str = "abab",
                      n_syllables: int = 8,
                      sampling_k: int = None,
                      beam_width: int = None,
                      seed: int = 1337,
                      temperature: float = 1.0,
                      last_text: str = "") -> str:
        """
        Generate a poem. Either sampling_k or beam_width must be given.

        :param model_path: path to the model.
        :param token_vocab_path: path to the token vocabulary; defaults to
            ``<model_path>/vocabulary``.
        :param stress_vocab_path: path to the stress vocabulary; defaults to
            ``<model_path>/stress.pickle``.
        :param metre_schema: metre schema.
        :param rhyme_pattern: rhyme pattern.
        :param n_syllables: number of syllables per line.
        :param sampling_k: top-k for sampling.
        :param beam_width: beam search width.
        :param seed: random seed.
        :param temperature: generation temperature.
        :param last_text: the last line.
        :return: the poem; None if generation was not successful.
        """
        token_vocab_path = token_vocab_path or os.path.join(
            model_path, "vocabulary")
        stress_vocab_path = stress_vocab_path or os.path.join(
            model_path, "stress.pickle")
        generator = self.get_generator(model_path, token_vocab_path,
                                       stress_vocab_path)
        poem = generator.generate_poem(metre_schema=metre_schema,
                                       rhyme_pattern=rhyme_pattern,
                                       n_syllables=n_syllables,
                                       sampling_k=sampling_k,
                                       beam_width=beam_width,
                                       temperature=temperature,
                                       seed=seed,
                                       last_text=last_text)
        return poem

    def get_word_rhymes(self,
                        word: str,
                        vocab_dump_path: str,
                        markup_path: str = None) -> List[str]:
        """
        Find rhymes for the given word in the vocabulary.

        :param word: the word.
        :param vocab_dump_path: path of the vocabulary dump.
        :param markup_path: path to the markups.
        :return: lower-cased texts of the vocabulary words that rhyme with ``word``.
        """
        markup_word = self.get_markup(word).lines[0].words[0]
        markup_word.set_stresses(self.get_stresses(word))
        rhymes = []
        vocabulary = self.get_vocabulary(vocab_dump_path, markup_path)
        for i in range(vocabulary.size()):
            # Fetch each candidate once and reuse it for the check and the result.
            candidate = vocabulary.get_word(i)
            if Rhymes.is_rhyme(markup_word, candidate):
                rhymes.append(candidate.text.lower())
        return rhymes