def expand(self, token: pynini.FstLike) -> pynini.Fst:
    """Finds deduplication candidates for a token in a lexicon.

    Args:
      token: a "cooooool"-like token.

    Returns:
      An FST representing a lattice of possible matches.
    """
    try:
        lattice = rewrite.rewrite_lattice(token, self._dedup)
        return rewrite.rewrite_lattice(lattice, self._lexicon)
    except rewrite.Error:
        return pynini.Fst()
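
# A minimal, self-contained sketch of how the `_dedup` rule and `_lexicon`
# acceptor assumed above might be built. All names and the toy lexicon are
# illustrative, not the original implementation.
import pynini
from pynini.lib import rewrite

_letters = [chr(i) for i in range(ord("a"), ord("z") + 1)]
# Map any run of a letter down to one or two copies of it, so that
# "cooooool" yields candidates like "col" and "cool".
_squeeze = pynini.union(*(
    pynini.cross(pynini.closure(c, 1), c)
    | pynini.cross(pynini.closure(c, 2), c + c)
    for c in _letters))
dedup = pynini.closure(_squeeze, 1).optimize()
lexicon = pynini.union("cool", "col").optimize()

lattice = rewrite.rewrite_lattice("cooooool", dedup)
print(rewrite.lattice_to_strings(rewrite.rewrite_lattice(lattice, lexicon)))
# e.g. ['col', 'cool']
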
def _rewrite_lattice(
        self,
        string: pynini.FstLike,
        token_type: Optional[pynini.TokenType] = None) -> pynini.Fst:
    """Applies all rules to an input string.

    Args:
      string: Input string or FST.
      token_type: Optional input token type, or symbol table.

    Returns:
      The lattice of output strings.

    Raises:
      Error: No rules requested.
    """
    if not self.rules:
        raise Error("No rules requested")
    lattice = string
    for rule in self.rules:
        lattice = rewrite.rewrite_lattice(lattice, rule, token_type)
    # The loop runs at least once, so `lattice` is already an FST here;
    # this conversion is purely defensive.
    if not isinstance(lattice, pynini.Fst):
        lattice = pynini.accep(lattice, token_type=token_type)
    return lattice
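
# Hedged sketch of the cascade pattern above: each rule's output lattice
# feeds the next rule. The two context-dependent rules are invented for
# illustration.
import pynini
from pynini.lib import rewrite

sigma_star = pynini.closure(pynini.union(*"abc")).optimize()
rule1 = pynini.cdrewrite(pynini.cross("a", "b"), "", "", sigma_star)
rule2 = pynini.cdrewrite(pynini.cross("b", "c"), "", "", sigma_star)

lattice = "aba"
for rule in (rule1, rule2):
    lattice = rewrite.rewrite_lattice(lattice, rule)
print(rewrite.lattice_to_top_string(lattice))  # -> "ccc"
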
def decode(self, sentence: str) -> str:
    """Decodes sentence with the Chatspeak model + LM.

    Args:
      sentence: an input sentence; assumed to be non-empty.

    Returns:
      String representing the normalized sentence.
    """
    it = iter(sentence.split())
    token = next(it)  # Raises StopIteration on an empty sentence.
    lattice = self.token_lattice(token)
    for token in it:
        lattice.concat(" ")
        lattice.concat(self.token_lattice(token))
    lattice.optimize()
    # Scores with the LM.
    lattice = rewrite.rewrite_lattice(lattice, self._bytes_to_lm_mapper)
    lattice = rewrite.rewrite_lattice(lattice, self._lm)
    lattice = rewrite.rewrite_lattice(lattice, self._lm_to_bytes_mapper)
    return rewrite.lattice_to_top_string(lattice)
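
# Illustrative stand-in for `token_lattice` showing the concatenation
# pattern above; a real model would return weighted alternatives per token.
import pynini

def token_lattice(token: str) -> pynini.Fst:
    # Hypothetical: identity candidates only.
    return pynini.accep(token)

it = iter("c u l8r".split())
lattice = token_lattice(next(it))
for token in it:
    lattice.concat(" ")  # restore the word separator between tokens
    lattice.concat(token_lattice(token))
lattice.optimize()
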
def expand(self, token: pynini.FstLike) -> pynini.Fst:
    """Finds regexp candidates for a token.

    Args:
      token: a "zomggg"-like token.

    Returns:
      An FST representing a lattice of possible matches.
    """
    try:
        return rewrite.rewrite_lattice(token, self._regexps)
    except rewrite.Error:
        return pynini.Fst()
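
# Sketch of a toy `_regexps` transducer for tokens like "zomggg": a core
# pattern plus any number of trailing repeats maps to a canonical form.
# Pattern and target are invented for illustration.
import pynini
from pynini.lib import rewrite

regexps = pynini.cross(
    pynini.accep("zomg") + pynini.closure("g"), "oh my god").optimize()
print(rewrite.lattice_to_top_string(
    rewrite.rewrite_lattice("zomggg", regexps)))  # -> "oh my god"
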
def process(self, pron, num_nbest=1):
    """Returns the num_nbest best (word, weight) candidates for a pronunciation."""
    # Normalize digraph transcriptions to single characters.
    pron = (pron.replace("sh", "š").replace("ou", "õ").replace("ae", "ä")
                .replace("oe", "ö").replace("ue", "ü").replace("kk", "K")
                .replace("pp", "P").replace("tt", "T").replace(" ", ""))
    orig_pron = accep(pron)
    lattice = (orig_pron @ self.inverse_transformer).project("output")
    lattice.optimize()
    if self.char_lm:
        # Rescore the lattice with the character LM.
        lattice = rewrite.rewrite_lattice(lattice, self.bytes_to_lm_mapper)
        lattice = rewrite.rewrite_lattice(lattice, self.char_lm)
        lattice = rewrite.rewrite_lattice(lattice, self.lm_to_bytes_mapper)
        lattice.optimize()
    shortest_paths = shortestpath(lattice, nshortest=num_nbest, unique=False)
    result = []
    for word, weight in zip(shortest_paths.paths().ostrings(),
                            shortest_paths.paths().weights()):
        result.append((word, weight))
    return result
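
# Minimal sketch of the n-best extraction step above, using a hand-built
# weighted lattice in place of the transformer output (weights invented).
import pynini

toy_lattice = pynini.union(
    pynini.accep("tere", weight=1.0),
    pynini.accep("dere", weight=2.0)).optimize()
best = pynini.shortestpath(toy_lattice, nshortest=2, unique=False)
for word, weight in zip(best.paths().ostrings(), best.paths().weights()):
    print(word, weight)
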
def optimal_rewrites(
    string: pynini.FstLike,
    rule: pynini.Fst,
    input_token_type: Optional[TokenType] = None,
    output_token_type: Optional[TokenType] = None,
    threshold: float = 1,
) -> List[str]:
    """Returns all optimal rewrites.

    Args:
      string: Input string or FST.
      rule: Input rule WFST.
      input_token_type: Optional input token type, or symbol table.
      output_token_type: Optional output token type, or symbol table.
      threshold: Threshold for weights (1 keeps only optimal paths; 0 keeps
        all paths).

    Returns:
      A list of output strings.
    """
    lattice = rewrite.rewrite_lattice(string, rule, input_token_type)
    lattice = threshold_lattice_to_dfa(lattice, threshold, 4)
    return rewrite.lattice_to_strings(lattice, output_token_type)
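
# Usage sketch for the function above, with an invented weighted rule. Per
# the docstring, the default threshold of 1 keeps only optimal rewrites,
# while 0 admits all paths.
import pynini
from pynini.lib import pynutil

toy_rule = pynini.union(
    pynini.cross("gray", "grey"),
    pynutil.add_weight(pynini.cross("gray", "graey"), 2.0)).optimize()
print(optimal_rewrites("gray", toy_rule))               # optimal only
print(optimal_rewrites("gray", toy_rule, threshold=0))  # all paths
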
def normalize(
    self,
    text: str,
    n_tagged: int,
    punct_post_process: bool = True,
    verbose: bool = False,
):
    """Main function. Normalizes tokens from written to spoken form,
    e.g. "12 kg" -> "twelve kilograms".

    Args:
      text: string that may include semiotic classes
      n_tagged: number of tagged options to consider; -1 returns all
        possible tagged options
      punct_post_process: whether to normalize punctuation
      verbose: whether to print intermediate meta information

    Returns:
      Normalized text options (there are usually multiple ways of
      normalizing a given semiotic class); in LM mode, a tuple of the
      options and their weights.
    """
    assert len(text.split()) < 500, (
        "Your input is too long. Please split up the input into sentences, "
        "or strings with fewer than 500 words")

    original_text = text
    text = pre_process(text)  # to handle []
    text = text.strip()
    if not text:
        if verbose:
            print(text)
        return text
    text = pynini.escape(text)

    if self.lm:
        if self.lang not in ["en"]:
            raise ValueError(f"{self.lang} is not supported in LM mode")
        if self.lang == "en":
            try:
                lattice = rewrite.rewrite_lattice(text, self.tagger.fst_no_digits)
            except pynini.lib.rewrite.Error:
                lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
            lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
            tagged_texts = [(x[1], float(x[2])) for x in lattice.paths().items()]
            tagged_texts.sort(key=lambda x: x[1])
            tagged_texts, weights = list(zip(*tagged_texts))
    else:
        if n_tagged == -1:
            if self.lang == "en":
                try:
                    tagged_texts = rewrite.rewrites(text, self.tagger.fst_no_digits)
                except pynini.lib.rewrite.Error:
                    tagged_texts = rewrite.rewrites(text, self.tagger.fst)
            else:
                tagged_texts = rewrite.rewrites(text, self.tagger.fst)
        else:
            if self.lang == "en":
                try:
                    tagged_texts = rewrite.top_rewrites(
                        text, self.tagger.fst_no_digits, nshortest=n_tagged)
                except pynini.lib.rewrite.Error:
                    tagged_texts = rewrite.top_rewrites(
                        text, self.tagger.fst, nshortest=n_tagged)
            else:
                tagged_texts = rewrite.top_rewrites(
                    text, self.tagger.fst, nshortest=n_tagged)

    # Non-deterministic English normalization uses the tagger composed with
    # the verbalizer, with no permutation in between.
    if self.lang == "en":
        normalized_texts = tagged_texts
    else:
        normalized_texts = []
        for tagged_text in tagged_texts:
            self._verbalize(tagged_text, normalized_texts, verbose=verbose)

    if len(normalized_texts) == 0:
        raise ValueError("Normalization produced no output.")

    if punct_post_process:
        # Do post-processing based on the Moses detokenizer.
        if self.processor:
            normalized_texts = [
                self.processor.detokenize([t]) for t in normalized_texts
            ]
        normalized_texts = [
            post_process_punct(input=original_text, normalized_text=t)
            for t in normalized_texts
        ]
    if self.lm:
        return normalized_texts, weights
    normalized_texts = set(normalized_texts)
    return normalized_texts
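
# Toy illustration (invented tagger) of the n-best tagging call used above:
# `top_rewrites` returns the `nshortest` lowest-cost tagged candidates.
import pynini
from pynini.lib import pynutil, rewrite

# The alphabet covers both input and output symbols of the toy rule.
sigma_star = pynini.closure(pynini.union(*"12 kgtwelvon")).optimize()
toy_tagger = pynini.cdrewrite(
    pynini.cross("12", "twelve")
    | pynutil.add_weight(pynini.cross("12", "one two"), 1.0),
    "", "", sigma_star).optimize()
print(rewrite.top_rewrites("12 kg", toy_tagger, nshortest=2))
# e.g. ['twelve kg', 'one two kg']
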
def _run(self) -> typing.Generator[typing.Tuple[int, int], None, None]:
    """Run the function."""
    db_engine = sqlalchemy.create_engine(f"sqlite:///{self.db_path}?mode=ro&nolock=1")
    with open(self.log_path, "w", encoding="utf8") as log_file, Session(db_engine) as session:
        dictionaries = (
            session.query(Dictionary)
            .join(Dictionary.speakers)
            .filter(Speaker.job_id == self.job_name)
            .distinct()
        )
        # Parse the context width and central position out of tree-info.
        tree_proc = subprocess.Popen(
            [thirdparty_binary("tree-info"), self.tree_path],
            encoding="utf8",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, _ = tree_proc.communicate()
        context_width = 1
        central_pos = 0
        for line in stdout.split("\n"):
            text = line.strip().split(" ")
            if text[0] == "context-width":
                context_width = int(text[1])
            elif text[0] == "central-position":
                central_pos = int(text[1])
        out_disambig = os.path.join(self.working_dir, f"{self.job_name}.disambig")
        ilabels_temp = os.path.join(self.working_dir, f"{self.job_name}.ilabels")
        clg_path = os.path.join(self.working_dir, f"{self.job_name}.clg.temp")
        ha_out_disambig = os.path.join(
            self.working_dir, f"{self.job_name}.ha_out_disambig.temp")
        for d in dictionaries:
            fst_ark_path = self.fst_ark_paths[d.id]
            text_path = self.text_int_paths[d.id]
            if d.use_g2p:
                import pynini
                from pynini.lib import rewrite

                from montreal_forced_aligner.g2p.generator import threshold_lattice_to_dfa

                fst = pynini.Fst.read(d.lexicon_fst_path)
                token_type = pynini.SymbolTable.read_text(d.grapheme_symbol_table_path)
                utterances = (
                    session.query(Utterance.kaldi_id, Utterance.normalized_character_text)
                    .join(Utterance.speaker)
                    .filter(Utterance.ignored == False)  # noqa
                    .filter(Utterance.normalized_character_text != "")
                    .filter(Speaker.job_id == self.job_name)
                    .filter(Speaker.dictionary_id == d.id)
                    .order_by(Utterance.kaldi_id)
                )
                with open(fst_ark_path, "wb") as fst_output_file:
                    for utt_id, full_text in utterances:
                        full_text = f"<s> {full_text} </s>"
                        lattice = rewrite.rewrite_lattice(full_text, fst, token_type)
                        lattice = threshold_lattice_to_dfa(lattice, 2.0)
                        fst_string = lattice.write_to_string()
                        clg_compose_proc = subprocess.Popen(
                            [
                                thirdparty_binary("fstcomposecontext"),
                                f"--context-size={context_width}",
                                f"--central-position={central_pos}",
                                f"--read-disambig-syms={d.disambiguation_symbols_int_path}",
                                f"--write-disambig-syms={out_disambig}",
                                ilabels_temp,
                                "-",
                                "-",
                            ],
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            stderr=log_file,
                            env=os.environ,
                        )
                        clg_sort_proc = subprocess.Popen(
                            [
                                thirdparty_binary("fstarcsort"),
                                "--sort_type=ilabel",
                                "-",
                                clg_path,
                            ],
                            stdin=clg_compose_proc.stdout,
                            stderr=log_file,
                            env=os.environ,
                        )
                        clg_compose_proc.stdin.write(fst_string)
                        clg_compose_proc.stdin.flush()
                        clg_compose_proc.stdin.close()
                        clg_sort_proc.communicate()
                        make_h_proc = subprocess.Popen(
                            [
                                thirdparty_binary("make-h-transducer"),
                                f"--disambig-syms-out={ha_out_disambig}",
                                ilabels_temp,
                                self.tree_path,
                                self.model_path,
                            ],
                            stderr=log_file,
                            stdout=subprocess.PIPE,
                            env=os.environ,
                        )
                        hclg_compose_proc = subprocess.Popen(
                            [thirdparty_binary("fsttablecompose"), "-", clg_path, "-"],
                            stderr=log_file,
                            stdin=make_h_proc.stdout,
                            stdout=subprocess.PIPE,
                            env=os.environ,
                        )
                        hclg_determinize_proc = subprocess.Popen(
                            [thirdparty_binary("fstdeterminizestar"), "--use-log=true"],
                            stdin=hclg_compose_proc.stdout,
                            stdout=subprocess.PIPE,
                            stderr=log_file,
                            env=os.environ,
                        )
                        hclg_rmsymbols_proc = subprocess.Popen(
                            [thirdparty_binary("fstrmsymbols"), ha_out_disambig],
                            stdin=hclg_determinize_proc.stdout,
                            stdout=subprocess.PIPE,
                            stderr=log_file,
                            env=os.environ,
                        )
                        hclg_rmeps_proc = subprocess.Popen(
[thirdparty_binary("fstrmepslocal")], stdin=hclg_rmsymbols_proc.stdout, stdout=subprocess.PIPE, stderr=log_file, env=os.environ, ) hclg_minimize_proc = subprocess.Popen( [thirdparty_binary("fstminimizeencoded")], stdin=hclg_rmeps_proc.stdout, stdout=subprocess.PIPE, stderr=log_file, env=os.environ, ) hclg_self_loop_proc = subprocess.Popen( [ thirdparty_binary("add-self-loops"), "--self-loop-scale=0.1", "--reorder=true", self.model_path, "-", "-", ], stdin=hclg_minimize_proc.stdout, stdout=subprocess.PIPE, stderr=log_file, env=os.environ, ) stdout, _ = hclg_self_loop_proc.communicate() self.check_call(hclg_minimize_proc) fst_output_file.write(utt_id.encode("utf8") + b" ") fst_output_file.write(stdout) yield 1, 0 else: proc = subprocess.Popen( [ thirdparty_binary("compile-train-graphs"), f"--read-disambig-syms={d.disambiguation_symbols_int_path}", self.tree_path, self.model_path, d.lexicon_fst_path, f"ark:{text_path}", f"ark:{fst_ark_path}", ], stderr=subprocess.PIPE, encoding="utf8", env=os.environ, ) for line in proc.stderr: log_file.write(line) m = self.progress_pattern.match(line.strip()) if m: yield int(m.group("succeeded")), int( m.group("failed")) self.check_call(proc)
def decode(self, t9_input: pynini.FstLike) -> pynini.Fst:
    """Decodes a T9 digit sequence into the lexicon entries it could spell."""
    lattice = rewrite.rewrite_lattice(t9_input, self._decoder)
    return pynini.intersect(lattice, self._lexicon)
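
# Toy sketch of the decoder and lexicon assumed above: each digit expands
# to its key's letters, and intersecting the output-projected lattice with
# the lexicon keeps only real words. Keypad and lexicon are illustrative.
import pynini
from pynini.lib import rewrite

keys = {"2": "abc", "3": "def", "4": "ghi", "5": "jkl",
        "6": "mno", "7": "pqrs", "8": "tuv", "9": "wxyz"}
decoder = pynini.closure(pynini.union(*(
    pynini.cross(digit, pynini.union(*letters))
    for digit, letters in keys.items()))).optimize()
lexicon = pynini.union("dog", "fog", "dig").optimize()

lattice = rewrite.rewrite_lattice("364", decoder).project("output")
print(list(pynini.intersect(lattice, lexicon).paths().ostrings()))
# e.g. ['dog', 'fog']
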
def __init__(self, cardinal: GraphFst, decimal: GraphFst, deterministic: bool = True):
    super().__init__(name="money", kind="classify", deterministic=deterministic)
    cardinal_graph = cardinal.graph
    graph_decimal_final = decimal.final_graph_wo_negative

    unit_singular = pynini.string_file(get_abs_path("data/currency/currency.tsv"))
    unit_plural = convert_space(unit_singular @ SINGULAR_TO_PLURAL)
    unit_singular = convert_space(unit_singular)

    graph_unit_singular = pynutil.insert("currency: \"") + unit_singular + pynutil.insert("\"")
    graph_unit_plural = pynutil.insert("currency: \"") + unit_plural + pynutil.insert("\"")

    singular_graph = (
        graph_unit_singular
        + pynutil.insert(" integer_part: \"")
        + pynini.cross("1", "one")
        + pynutil.insert("\"")
    )

    graph_decimal = graph_unit_plural + insert_space + graph_decimal_final

    if deterministic:
        graph_integer = (
            graph_unit_plural
            + pynutil.insert(" integer_part: \"")
            + ((NEMO_SIGMA - "1") @ cardinal_graph)
            + pynutil.insert("\"")
        )
    else:
        graph_integer = (
            graph_unit_plural
            + pynutil.insert(" integer_part: \"")
            + ((NEMO_SIGMA - "1") @ (get_hundreds_graph(deterministic) | cardinal_graph))
            + pynutil.insert("\"")
        )

    graph_decimal |= singular_graph + insert_space + graph_decimal_final
    graph_integer |= singular_graph

    final_graph = graph_integer | graph_decimal

    if not deterministic:
        currencies = load_labels(get_abs_path("data/currency/currency.tsv"))
        zero_graph = pynini.cross("0", "") | pynini.accep("0")
        # Add the minor currency part only when there are two digits after the point:
        # .01 -> {zero one cent, one cent}, .05 -> {oh five, five cents}.
        two_digits_fractional_part = (
            NEMO_SIGMA
            + pynini.closure(NEMO_DIGIT)
            + (
                (pynini.accep(".") + (NEMO_DIGIT ** 2 | zero_graph + (NEMO_DIGIT - "0")))
                | pynutil.delete(".") + pynini.cross(pynini.closure("0", 1), "")
            )
        )

        integer_graph = None
        decimal_graph_with_minor = None
        decimal_graph_default = None

        for curr_symbol, curr_name in currencies:
            curr_symbol_graph = pynutil.delete(curr_symbol)
            graph_end = pynutil.insert(" currency: \"" + curr_symbol + "\"")
            preserve_order = pynutil.insert(" preserve_order: True")
            integer_part = decimal.graph_integer + graph_end + preserve_order

            # "$4" -> 'integer_part: "four" currency: "$" preserve_order: True' -> "four dollars"
            integer_graph_curr = curr_symbol_graph + integer_part
            # Remove the fractional part if it contains only zeros:
            # "$4.00" -> 'integer_part: "four" currency: "$" preserve_order: True' -> "four dollars"
            integer_graph_curr |= pynini.compose(two_digits_fractional_part, integer_graph_curr)

            decimal_graph_with_minor_curr = (
                curr_symbol_graph
                + pynini.closure(integer_part, 0, 1)
                + pynini.cross(".", " ")
                + decimal.graph_fractional
                + graph_end
            )

            # "$.5" -> 'fractional_part: "five" currency: "dollars"' -> "point five dollars"
            decimal_graph_default_curr = (
                pynutil.delete("currency: \"" + pynini.compose(curr_symbol, unit_plural) + "\"")
                + delete_space
                + pynini.accep("fractional_part")
                + NEMO_SIGMA
                + pynutil.insert(" currency: \"" + pynini.compose(curr_symbol, unit_plural) + "\"")
            )

            # "$4.5" -> 'integer_part: "four" fractional_part: "five" currency: "dollars"'
            #   -> "four point five dollars"
            decimal_graph_default_curr |= (
                pynutil.delete("currency: \"" + curr_name + pynini.closure(NEMO_NOT_QUOTE) + "\"")
                + delete_space
                + pynini.accep("integer_part")
                + NEMO_SIGMA
                + pynini.accep("fractional_part")
                + NEMO_SIGMA
                + pynutil.insert(" currency: \"" + pynini.compose(curr_symbol, unit_plural) + "\"")
            )

            # "£4 billion" -> 'integer_part: "four" quantity: "billion" currency: "pounds"'
            #   -> "four billion pounds"
            decimal_graph_default_curr |= (
                pynutil.delete("currency: \"")
                + pynutil.delete(
                    rewrite.rewrite_lattice(
                        curr_symbol, pynini.compose(curr_symbol, unit_plural))
                    + "\" ")
                + pynini.difference(NEMO_SIGMA, "fractional_part")
                + pynutil.insert(" currency: \"" + pynini.compose(curr_symbol, unit_plural) + "\"")
            )

            decimal_graph_with_minor_curr = pynini.compose(
                two_digits_fractional_part, decimal_graph_with_minor_curr)
            decimal_graph_default_curr = pynini.compose(
                graph_decimal, decimal_graph_default_curr)

            integer_graph = (
                integer_graph_curr
                if integer_graph is None
                else pynini.union(integer_graph, integer_graph_curr)
            )
            decimal_graph_with_minor = (
                decimal_graph_with_minor_curr
                if decimal_graph_with_minor is None
                else pynini.union(decimal_graph_with_minor, decimal_graph_with_minor_curr)
            )
            decimal_graph_default = (
                decimal_graph_default_curr
                if decimal_graph_default is None
                else pynini.union(decimal_graph_default, decimal_graph_default_curr)
            )

        final_graph = decimal_graph_with_minor | decimal_graph_default | integer_graph

    final_graph = self.add_tokens(final_graph)
    self.fst = final_graph.optimize()
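
# Tiny illustration of the tagging primitives this grammar is built from:
# pynutil.delete consumes input symbols, pynutil.insert adds output-only
# material, and pynini.cross maps one string onto another. The example tag
# format is invented.
import pynini
from pynini.lib import pynutil, rewrite

tag = (pynutil.delete("$")
       + pynutil.insert("integer_part: \"")
       + pynini.cross("1", "one")
       + pynutil.insert("\" currency: \"dollar\""))
print(rewrite.one_top_rewrite("$1", tag))
# -> integer_part: "one" currency: "dollar"
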
def expand(self, token: pynini.FstLike) -> pynini.Fst:
    """Returns the lattice of lexicon entries matching the token, or an empty FST."""
    try:
        return rewrite.rewrite_lattice(token, self._lexicon)
    except rewrite.Error:
        return pynini.Fst()