Example #1
  def expand(self, token: pynini.FstLike) -> pynini.Fst:
    """Finds deduplication candidates for a token in a lexicon.

    Args:
      token: a "cooooool"-like token.

    Returns:
      An FST representing a lattice of possible matches.
    """
    try:
      lattice = rewrite.rewrite_lattice(token, self._dedup)
      return rewrite.rewrite_lattice(lattice, self._lexicon)
    except rewrite.Error:
      return pynini.Fst()
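A minimal, self-contained sketch of how such an expander's inputs can be wired up. The dedup rule and the two-word lexicon below are hypothetical stand-ins for the project's real grammars:

    import string

    import pynini
    from pynini.lib import byte, rewrite

    sigma_star = pynini.closure(byte.BYTE)
    # Optionally collapse any run of two or more copies of a letter to one copy.
    dedup = pynini.cdrewrite(
        pynini.union(*(pynini.cross(pynini.closure(c, 2), c)
                       for c in string.ascii_lowercase)),
        "", "", sigma_star, mode="opt")
    lexicon = pynini.union("cool", "col").optimize()

    lattice = rewrite.rewrite_lattice("cooooool", dedup)
    lattice = rewrite.rewrite_lattice(lattice, lexicon)
    print(rewrite.lattice_to_strings(lattice))  # e.g. ['cool', 'col']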
Example #2
    def _rewrite_lattice(
            self,
            string: pynini.FstLike,
            token_type: Optional[pynini.TokenType] = None) -> pynini.Fst:
        """Applies all rules to an input string.

        Args:
          string: Input string or FST.
          token_type: Optional input token type, or symbol table.

        Returns:
          The lattice of output strings.

        Raises:
          Error: No rules requested.
        """
        if not self.rules:
            raise Error("No rules requested")
        lattice = string
        for rule in self.rules:
            lattice = rewrite.rewrite_lattice(lattice, rule, token_type)
        # Ensure a plain-string input is returned as an FST.
        if not isinstance(lattice, pynini.Fst):
            lattice = pynini.accep(lattice, token_type=token_type)
        return lattice
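Driving such a cascade by hand looks roughly like this; the two toy cdrewrite rules stand in for whatever self.rules holds:

    import pynini
    from pynini.lib import byte, rewrite

    sigma_star = pynini.closure(byte.BYTE)
    rules = [
        pynini.cdrewrite(pynini.cross("kk", "K"), "", "", sigma_star),
        pynini.cdrewrite(pynini.cross("tt", "T"), "", "", sigma_star),
    ]
    lattice = "mokka ja kotti"
    for rule in rules:
        lattice = rewrite.rewrite_lattice(lattice, rule)
    print(rewrite.lattice_to_top_string(lattice))  # "moKa ja koTi"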
Example #3
    def decode(self, sentence: str) -> str:
        """Decodes sentence with the Chatspeak model + LM.

    Args:
      sentence: an input sentence.

    Returns:
      String representing the normalized sentence.
    """
        it = iter(sentence.split())
        token = next(it)
        lattice = self.token_lattice(token)
        for token in it:
            lattice.concat(" ")
            lattice.concat(self.token_lattice(token))
        lattice.optimize()
        # Scores with LM.
        lattice = rewrite.rewrite_lattice(lattice, self._bytes_to_lm_mapper)
        lattice = rewrite.rewrite_lattice(lattice, self._lm)
        lattice = rewrite.rewrite_lattice(lattice, self._lm_to_bytes_mapper)
        return rewrite.lattice_to_top_string(lattice)
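Each per-token lattice is an ordinary FST, so the concatenation step can be sketched with plain unions standing in for token_lattice(); the candidate sets here are hypothetical:

    import pynini
    from pynini.lib import rewrite

    lattice = pynini.union("be right back", "brb").optimize()   # candidates for "brb"
    lattice.concat(" ")
    lattice.concat(pynini.union("tonight", "2nite").optimize()) # candidates for "2nite"
    lattice.optimize()
    print(rewrite.lattice_to_strings(lattice))  # the four candidate sentences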
Example #4
  def expand(self, token: pynini.FstLike) -> pynini.Fst:
    """Finds regexps candidates for a token.

    Args:
      token: a "zomggg"-like token.

    Returns:
      An FST representing a lattice of possible matches.
    """
    try:
      return rewrite.rewrite_lattice(token, self._regexps)
    except rewrite.Error:
      return pynini.Fst()
Example #5
    def process(self, pron, num_nbest=1):
        """Maps a pronunciation to spelling candidates, optionally LM-rescored."""
        # Normalize digraph spellings to the single-character symbols the
        # transducer expects.
        for old, new in (("sh", "š"), ("ou", "õ"), ("ae", "ä"), ("oe", "ö"),
                         ("ue", "ü"), ("kk", "K"), ("pp", "P"), ("tt", "T"),
                         (" ", "")):
            pron = pron.replace(old, new)
        # `accep` and `shortestpath` are presumably imported from pynini.
        orig_pron = accep(pron)
        lattice = (orig_pron @ self.inverse_transformer).project('output')
        lattice.optimize()
        if self.char_lm:
            # Rescore the lattice with the character language model.
            lattice = rewrite.rewrite_lattice(lattice, self.bytes_to_lm_mapper)
            lattice = rewrite.rewrite_lattice(lattice, self.char_lm)
            lattice = rewrite.rewrite_lattice(lattice, self.lm_to_bytes_mapper)
        lattice.optimize()
        shortest_paths = shortestpath(lattice, nshortest=num_nbest, unique=False)
        return list(zip(shortest_paths.paths().ostrings(),
                        shortest_paths.paths().weights()))
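The core trick above is inverting a grapheme-to-phoneme transducer so that pronunciations map back to spellings; a toy version with a hypothetical two-entry G2P:

    import pynini

    g2p = pynini.string_map([("sea", "sii"), ("see", "sii")])  # hypothetical entries
    inverse_transformer = pynini.invert(g2p)
    lattice = (pynini.accep("sii") @ inverse_transformer).project("output")
    lattice.optimize()
    print(set(lattice.paths().ostrings()))  # {'sea', 'see'}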
Example #6
def optimal_rewrites(
    string: pynini.FstLike,
    rule: pynini.Fst,
    input_token_type: Optional[TokenType] = None,
    output_token_type: Optional[TokenType] = None,
    threshold: float = 1,
) -> List[str]:
    """Returns all optimal rewrites.
    Args:
    string: Input string or FST.
    rule: Input rule WFST.
    input_token_type: Optional input token type, or symbol table.
    output_token_type: Optional output token type, or symbol table.
    threshold: Threshold for weights (1 is optimal only, 0 is for all paths)
    Returns:
    A tuple of output strings.
    """
    lattice = rewrite.rewrite_lattice(string, rule, input_token_type)
    lattice = threshold_lattice_to_dfa(lattice, threshold, 4)
    return rewrite.lattice_to_strings(lattice, output_token_type)
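Hypothetical usage with a toy weighted rule; with the default threshold of 1, only near-optimal paths survive the pruned determinization inside threshold_lattice_to_dfa:

    import pynini

    # "teh" -> "the" at cost 0; "teh" -> "ten" carries an extra weight of 4.
    rule = pynini.union(
        pynini.cross("teh", "the"),
        pynini.cross("teh", "ten") + pynini.accep("", weight=4),
    ).optimize()
    print(optimal_rewrites("teh", rule))  # ['the']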
Example #7
    def normalize(
        self,
        text: str,
        n_tagged: int,
        punct_post_process: bool = True,
        verbose: bool = False,
    ) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            n_tagged: number of tagged options to consider; set to -1 to get all possible tagged options
            punct_post_process: whether to normalize punctuation
            verbose: whether to print intermediate meta information

        Returns:
            normalized text options (usually there are multiple ways of normalizing a given semiotic class)
        """

        assert (
            len(text.split()) < 500
        ), "Your input is too long. Please split up the input into sentences, or strings with fewer than 500 words"
        original_text = text
        text = pre_process(text)  # to handle []

        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)

        if self.lm:
            if self.lang not in ["en"]:
                raise ValueError(f"{self.lang} is not supported in LM mode")

            if self.lang == "en":
                try:
                    lattice = rewrite.rewrite_lattice(
                        text, self.tagger.fst_no_digits)
                except pynini.lib.rewrite.Error:
                    lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
                lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
                tagged_texts = [(x[1], float(x[2]))
                                for x in lattice.paths().items()]
                tagged_texts.sort(key=lambda x: x[1])
                tagged_texts, weights = list(zip(*tagged_texts))
        else:
            if n_tagged == -1:
                if self.lang == "en":
                    try:
                        tagged_texts = rewrite.rewrites(
                            text, self.tagger.fst_no_digits)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.rewrites(text, self.tagger.fst)
                else:
                    tagged_texts = rewrite.rewrites(text, self.tagger.fst)
            else:
                if self.lang == "en":
                    try:
                        tagged_texts = rewrite.top_rewrites(
                            text,
                            self.tagger.fst_no_digits,
                            nshortest=n_tagged)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.top_rewrites(text,
                                                            self.tagger.fst,
                                                            nshortest=n_tagged)
                else:
                    tagged_texts = rewrite.top_rewrites(text,
                                                        self.tagger.fst,
                                                        nshortest=n_tagged)

        # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between
        if self.lang == "en":
            normalized_texts = tagged_texts
        else:
            normalized_texts = []
            for tagged_text in tagged_texts:
                self._verbalize(tagged_text, normalized_texts, verbose=verbose)

        if len(normalized_texts) == 0:
            raise ValueError("No normalized texts were produced")

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                normalized_texts = [
                    self.processor.detokenize([t]) for t in normalized_texts
                ]
                normalized_texts = [
                    post_process_punct(input=original_text, normalized_text=t)
                    for t in normalized_texts
                ]

        if self.lm:
            return normalized_texts, weights

        normalized_texts = set(normalized_texts)
        return normalized_texts
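A hypothetical call, assuming this method belongs to NeMo's NormalizerWithAudio; the constructor arguments shown are illustrative rather than prescribed by the snippet:

    normalizer = NormalizerWithAudio(input_case="cased", lang="en")
    options = normalizer.normalize("12 kg", n_tagged=10, punct_post_process=True)
    print(options)  # a set of candidate normalizations, e.g. {'twelve kilograms', ...}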
Example #8
    def _run(self) -> typing.Generator[typing.Tuple[int, int], None, None]:
        """Runs the job, yielding (succeeded, failed) counts as it progresses."""
        db_engine = sqlalchemy.create_engine(
            f"sqlite:///{self.db_path}?mode=ro&nolock=1")

        with open(self.log_path, "w",
                  encoding="utf8") as log_file, Session(db_engine) as session:
            dictionaries = (session.query(Dictionary).join(
                Dictionary.speakers).filter(
                    Speaker.job_id == self.job_name).distinct())

            tree_proc = subprocess.Popen(
                [thirdparty_binary("tree-info"), self.tree_path],
                encoding="utf8",
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stdout, _ = tree_proc.communicate()
            context_width = 1
            central_pos = 0
            for line in stdout.split("\n"):
                text = line.strip().split(" ")
                if text[0] == "context-width":
                    context_width = int(text[1])
                elif text[0] == "central-position":
                    central_pos = int(text[1])
            out_disambig = os.path.join(self.working_dir,
                                        f"{self.job_name}.disambig")
            ilabels_temp = os.path.join(self.working_dir,
                                        f"{self.job_name}.ilabels")
            clg_path = os.path.join(self.working_dir,
                                    f"{self.job_name}.clg.temp")
            ha_out_disambig = os.path.join(
                self.working_dir, f"{self.job_name}.ha_out_disambig.temp")
            for d in dictionaries:
                fst_ark_path = self.fst_ark_paths[d.id]
                text_path = self.text_int_paths[d.id]
                if d.use_g2p:
                    import pynini
                    from pynini.lib import rewrite

                    from montreal_forced_aligner.g2p.generator import threshold_lattice_to_dfa

                    fst = pynini.Fst.read(d.lexicon_fst_path)
                    token_type = pynini.SymbolTable.read_text(
                        d.grapheme_symbol_table_path)
                    utterances = (
                        session.query(
                            Utterance.kaldi_id,
                            Utterance.normalized_character_text).join(
                                Utterance.speaker).filter(
                                    Utterance.ignored == False)  # noqa
                        .filter(
                            Utterance.normalized_character_text != "").filter(
                                Speaker.job_id == self.job_name).filter(
                                    Speaker.dictionary_id == d.id).order_by(
                                        Utterance.kaldi_id))
                    with open(fst_ark_path, "wb") as fst_output_file:
                        for utt_id, full_text in utterances:
                            full_text = f"<s> {full_text} </s>"
                            lattice = rewrite.rewrite_lattice(
                                full_text, fst, token_type)
                            lattice = threshold_lattice_to_dfa(lattice, 2.0)
                            input = lattice.write_to_string()
                            clg_compose_proc = subprocess.Popen(
                                [
                                    thirdparty_binary("fstcomposecontext"),
                                    f"--context-size={context_width}",
                                    f"--central-position={central_pos}",
                                    f"--read-disambig-syms={d.disambiguation_symbols_int_path}",
                                    f"--write-disambig-syms={out_disambig}",
                                    ilabels_temp,
                                    "-",
                                    "-",
                                ],
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                stderr=log_file,
                                env=os.environ,
                            )
                            clg_sort_proc = subprocess.Popen(
                                [
                                    thirdparty_binary("fstarcsort"),
                                    "--sort_type=ilabel",
                                    "-",
                                    clg_path,
                                ],
                                stdin=clg_compose_proc.stdout,
                                stderr=log_file,
                                env=os.environ,
                            )
                            clg_compose_proc.stdin.write(input)
                            clg_compose_proc.stdin.flush()
                            clg_compose_proc.stdin.close()
                            clg_sort_proc.communicate()

                            make_h_proc = subprocess.Popen(
                                [
                                    thirdparty_binary("make-h-transducer"),
                                    f"--disambig-syms-out={ha_out_disambig}",
                                    ilabels_temp,
                                    self.tree_path,
                                    self.model_path,
                                ],
                                stderr=log_file,
                                stdout=subprocess.PIPE,
                                env=os.environ,
                            )
                            hclg_compose_proc = subprocess.Popen(
                                [
                                    thirdparty_binary("fsttablecompose"), "-",
                                    clg_path, "-"
                                ],
                                stderr=log_file,
                                stdin=make_h_proc.stdout,
                                stdout=subprocess.PIPE,
                                env=os.environ,
                            )

                            hclg_determinize_proc = subprocess.Popen(
                                [
                                    thirdparty_binary("fstdeterminizestar"),
                                    "--use-log=true"
                                ],
                                stdin=hclg_compose_proc.stdout,
                                stdout=subprocess.PIPE,
                                stderr=log_file,
                                env=os.environ,
                            )
                            hclg_rmsymbols_proc = subprocess.Popen(
                                [
                                    thirdparty_binary("fstrmsymbols"),
                                    ha_out_disambig
                                ],
                                stdin=hclg_determinize_proc.stdout,
                                stdout=subprocess.PIPE,
                                stderr=log_file,
                                env=os.environ,
                            )
                            hclg_rmeps_proc = subprocess.Popen(
                                [thirdparty_binary("fstrmepslocal")],
                                stdin=hclg_rmsymbols_proc.stdout,
                                stdout=subprocess.PIPE,
                                stderr=log_file,
                                env=os.environ,
                            )
                            hclg_minimize_proc = subprocess.Popen(
                                [thirdparty_binary("fstminimizeencoded")],
                                stdin=hclg_rmeps_proc.stdout,
                                stdout=subprocess.PIPE,
                                stderr=log_file,
                                env=os.environ,
                            )
                            hclg_self_loop_proc = subprocess.Popen(
                                [
                                    thirdparty_binary("add-self-loops"),
                                    "--self-loop-scale=0.1",
                                    "--reorder=true",
                                    self.model_path,
                                    "-",
                                    "-",
                                ],
                                stdin=hclg_minimize_proc.stdout,
                                stdout=subprocess.PIPE,
                                stderr=log_file,
                                env=os.environ,
                            )

                            stdout, _ = hclg_self_loop_proc.communicate()
                            self.check_call(hclg_minimize_proc)
                            fst_output_file.write(utt_id.encode("utf8") + b" ")
                            fst_output_file.write(stdout)
                            yield 1, 0

                else:
                    proc = subprocess.Popen(
                        [
                            thirdparty_binary("compile-train-graphs"),
                            f"--read-disambig-syms={d.disambiguation_symbols_int_path}",
                            self.tree_path,
                            self.model_path,
                            d.lexicon_fst_path,
                            f"ark:{text_path}",
                            f"ark:{fst_ark_path}",
                        ],
                        stderr=subprocess.PIPE,
                        encoding="utf8",
                        env=os.environ,
                    )
                    for line in proc.stderr:
                        log_file.write(line)
                        m = self.progress_pattern.match(line.strip())
                        if m:
                            yield int(m.group("succeeded")), int(
                                m.group("failed"))
                    self.check_call(proc)
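The imported threshold_lattice_to_dfa prunes each G2P lattice before graph compilation; a sketch of what such a helper can look like, assuming pruned determinization with a weight threshold (the name and defaults here are illustrative):

    import pynini

    def threshold_lattice_to_dfa_sketch(lattice: pynini.Fst,
                                        weight_threshold: float = 2.0,
                                        state_multiplier: int = 4) -> pynini.Fst:
        """Keeps near-best lattice paths and determinizes the result."""
        weight = pynini.Weight(lattice.weight_type(), weight_threshold)
        nstate = state_multiplier * lattice.num_states()
        return pynini.determinize(lattice, nstate=nstate, weight=weight).optimize()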
Example #9
 def decode(self, t9_input: pynini.FstLike) -> pynini.Fst:
     """Decodes a T9 digit sequence into a lattice of in-lexicon words."""
     lattice = rewrite.rewrite_lattice(t9_input, self._decoder)
     return pynini.intersect(lattice, self._lexicon)
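A self-contained sketch of the two FSTs such a decoder needs; the keypad map is the standard T9 layout, and the two-word lexicon is hypothetical:

    import pynini
    from pynini.lib import rewrite

    keys = {"2": "abc", "3": "def", "4": "ghi", "5": "jkl",
            "6": "mno", "7": "pqrs", "8": "tuv", "9": "wxyz"}
    decoder = pynini.closure(pynini.union(*(
        pynini.cross(digit, pynini.union(*letters))
        for digit, letters in keys.items()))).optimize()
    lexicon = pynini.union("hello", "help").optimize()

    lattice = rewrite.rewrite_lattice("43556", decoder)
    result = pynini.intersect(lattice, lexicon).optimize()
    print(list(result.paths().ostrings()))  # ['hello']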
Example #10
    def __init__(self,
                 cardinal: GraphFst,
                 decimal: GraphFst,
                 deterministic: bool = True):
        super().__init__(name="money",
                         kind="classify",
                         deterministic=deterministic)
        cardinal_graph = cardinal.graph
        graph_decimal_final = decimal.final_graph_wo_negative

        unit_singular = pynini.string_file(
            get_abs_path("data/currency/currency.tsv"))
        unit_plural = convert_space(unit_singular @ SINGULAR_TO_PLURAL)
        unit_singular = convert_space(unit_singular)

        graph_unit_singular = pynutil.insert(
            "currency: \"") + unit_singular + pynutil.insert("\"")
        graph_unit_plural = pynutil.insert(
            "currency: \"") + unit_plural + pynutil.insert("\"")

        singular_graph = (graph_unit_singular +
                          pynutil.insert(" integer_part: \"") +
                          pynini.cross("1", "one") + pynutil.insert("\""))

        graph_decimal = graph_unit_plural + insert_space + graph_decimal_final

        if deterministic:
            graph_integer = (graph_unit_plural +
                             pynutil.insert(" integer_part: \"") +
                             ((NEMO_SIGMA - "1") @ cardinal_graph) +
                             pynutil.insert("\""))
        else:
            graph_integer = (
                graph_unit_plural + pynutil.insert(" integer_part: \"") +
                ((NEMO_SIGMA - "1")
                 @ (get_hundreds_graph(deterministic) | cardinal_graph)) +
                pynutil.insert("\""))
            graph_decimal |= singular_graph + insert_space + graph_decimal_final

        graph_integer |= singular_graph

        final_graph = graph_integer | graph_decimal

        if not deterministic:
            currencies = load_labels(
                get_abs_path("data/currency/currency.tsv"))
            zero_graph = pynini.cross("0", "") | pynini.accep("0")
            # add minor currency part only when there are two digits after the point
            # .01 -> {zero one cent, one cent}, .05 -> {oh five, five cents}
            two_digits_fractional_part = (
                NEMO_SIGMA + pynini.closure(NEMO_DIGIT) +
                ((pynini.accep(".") + (NEMO_DIGIT**(2) | zero_graph +
                                       (NEMO_DIGIT - "0")))
                 | pynutil.delete(".") +
                 pynini.cross(pynini.closure("0", 1), "")))

            integer_graph = None
            decimal_graph_with_minor = None
            decimal_graph_default = None

            for curr_symbol, curr_name in currencies:
                curr_symbol_graph = pynutil.delete(curr_symbol)
                graph_end = pynutil.insert(" currency: \"" + curr_symbol +
                                           "\"")
                preserve_order = pynutil.insert(" preserve_order: True")
                integer_part = decimal.graph_integer + graph_end + preserve_order

                # "$4" -> 'integer_part: "four" currency: "$" preserve_order: True' -> four dollars
                integer_graph_curr = curr_symbol_graph + integer_part
                # remove fractional part if it contains only zeros
                # "$4.00" -> 'integer_part: "four" currency: "$" preserve_order: True' -> four dollars
                integer_graph_curr |= pynini.compose(
                    two_digits_fractional_part, integer_graph_curr)
                decimal_graph_with_minor_curr = (
                    curr_symbol_graph + pynini.closure(integer_part, 0, 1) +
                    pynini.cross(".", " ") + decimal.graph_fractional +
                    graph_end)

                # "$.5" -> 'fractional_part: "five" currency: "dollars"' -> point five dollars
                decimal_graph_default_curr = (
                    pynutil.delete("currency: \"" +
                                   pynini.compose(curr_symbol, unit_plural) +
                                   "\"") + delete_space +
                    pynini.accep("fractional_part") + NEMO_SIGMA +
                    pynutil.insert(" currency: \"" +
                                   pynini.compose(curr_symbol, unit_plural) +
                                   "\""))

                # "$4.5" -> 'integer_part: "four" fractional_part: "five" currency: "dollars"' -> "four point five dollars"
                decimal_graph_default_curr |= (
                    pynutil.delete("currency: \"" + curr_name +
                                   pynini.closure(NEMO_NOT_QUOTE) + "\"") +
                    delete_space + pynini.accep("integer_part") + NEMO_SIGMA +
                    pynini.accep("fractional_part") + NEMO_SIGMA +
                    pynutil.insert(" currency: \"" +
                                   pynini.compose(curr_symbol, unit_plural) +
                                   "\""))

                # "£4 billion" -> 'integer_part: "four" quantity: "billion" currency: "pounds"' -> "four billion dollars"
                decimal_graph_default_curr |= (
                    pynutil.delete("currency: \"") + pynutil.delete(
                        rewrite.rewrite_lattice(
                            curr_symbol,
                            pynini.compose(curr_symbol, unit_plural)) + "\" ")
                    + pynini.difference(NEMO_SIGMA, "fractional_part") +
                    pynutil.insert(" currency: \"" +
                                   pynini.compose(curr_symbol, unit_plural) +
                                   "\""))

                decimal_graph_with_minor_curr = pynini.compose(
                    two_digits_fractional_part, decimal_graph_with_minor_curr)
                decimal_graph_default_curr = pynini.compose(
                    graph_decimal, decimal_graph_default_curr)

                integer_graph = (integer_graph_curr
                                 if integer_graph is None else pynini.union(
                                     integer_graph, integer_graph_curr))
                decimal_graph_with_minor = (decimal_graph_with_minor_curr
                                            if decimal_graph_with_minor is None
                                            else pynini.union(
                                                decimal_graph_with_minor,
                                                decimal_graph_with_minor_curr))
                decimal_graph_default = (
                    decimal_graph_default_curr
                    if decimal_graph_default is None else pynini.union(
                        decimal_graph_default, decimal_graph_default_curr))

            final_graph = decimal_graph_with_minor | decimal_graph_default | integer_graph

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()
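A hypothetical smoke test, assuming the class above is NeMo's English MoneyFst tagger and that the cardinal and decimal taggers are importable as below; the exact tag string varies by grammar version:

    from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
    from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
    from pynini.lib import rewrite

    cardinal = CardinalFst(deterministic=True)
    decimal = DecimalFst(cardinal=cardinal, deterministic=True)
    money = MoneyFst(cardinal=cardinal, decimal=decimal, deterministic=True)
    print(rewrite.top_rewrite("$4", money.fst))
    # e.g. money { currency: "dollars" integer_part: "four" }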
Example #11
 def expand(self, token: pynini.FstLike) -> pynini.Fst:
   """Finds lexicon matches for a token, or an empty FST on failure."""
   try:
     return rewrite.rewrite_lattice(token, self._lexicon)
   except rewrite.Error:
     return pynini.Fst()