def matches(self,
            istring: pynini.FstLike,
            ostring: pynini.FstLike,
            input_token_type: Optional[pynini.TokenType] = None,
            output_token_type: Optional[pynini.TokenType] = None) -> bool:
  """Returns whether or not the rule cascade allows an input/output pair.

  Args:
    istring: Input string or FST.
    ostring: Output string or FST.
    input_token_type: Optional input token type, or symbol table.
    output_token_type: Optional output token type, or symbol table.

  Returns:
    Whether the input-output pair is generated by the rule.
  """
  # Local import keeps this chunk self-contained; resolves the old
  # TODO(kbg) suggesting `contextlib.nullcontext`.
  import contextlib
  lattice = self._rewrite_lattice(istring, input_token_type)
  # With no output token type, run the intersection under a no-op context;
  # otherwise under the requested token type.  This removes the previously
  # duplicated `pynini.intersect` call.
  if output_token_type is None:
    token_ctx = contextlib.nullcontext()
  else:
    token_ctx = pynini.default_token_type(output_token_type)
  with token_ctx:
    lattice = pynini.intersect(lattice, ostring, compose_filter="sequence")
  # An empty intersection has no start state, i.e. the pair is rejected.
  return lattice.start() != pynini.NO_STATE_ID
def gen(fsa, accept):
    """Samples strings from the intersection of ``fsa`` and ``accept``.

    Draws a uniform random sample of paths (fixed seed 0, maximum length
    100) and returns the corresponding string set.
    """
    # Hoisted loop invariant: the intersection does not depend on `i`.
    # Also dropped the no-op `functools.partial(pynini.randgen)` wrapper,
    # which bound no arguments at all.
    language = pynini.intersect(fsa, accept)
    loop = 10
    n = 100
    # NOTE(review): every iteration overwrites `rand`, so only the final
    # pass (num == 190) contributes to the result.  Kept as-is to preserve
    # behavior, but this loop looks like a bug — confirm the intent.
    for i in range(loop):
        num = int(n + n * i * 0.1)
        rand = pynini.randgen(
            language,
            npath=num,
            seed=0,
            select="uniform",
            max_length=100,
            weighted=False)
    return list_string_set(rand)
def _intersect_fsa_fsa(fsa1: FSA, fsa2: FSA) -> FSA:
    """Intersects two FSAs; symbol bookkeeping is inherited from `fsa1`."""
    result = FSA()
    result.fst = pynini.intersect(fsa1.fst, fsa2.fst)
    # The result reuses the first operand's symbol table and token maps.
    result.symbol_table = fsa1.symbol_table
    result.token_to_key = fsa1.token_to_key
    result.key_to_token = fsa1.key_to_token
    result.compile()
    return result
def get_best_expansion(self, expansions):
    """Scores `expansions` against the language model and returns the
    single best path."""
    print("combining expansions and LM ...")
    scored = pn.intersect(expansions, self.LM)
    print("optimizing intersection ...")
    scored.optimize()
    # Keep only the single shortest (best-scoring) path.
    best = pn.shortestpath(scored, nshortest=1).optimize()
    return best
def getPosString(fsa, min_len, max_len):
    """For each length in [min_len, max_len], collects (and prints) a
    shuffled list of the strings of exactly that length accepted by `fsa`."""
    fsa_dict = {}
    pos_str_dict = {}
    for length in range(min_len, max_len + 1):
        # Restrict the language to strings of exactly `length` symbols.
        fsa_dict[length] = pynini.intersect(
            fsa, pynini.closure(sigma, length, length))
        shuffled = list(np.random.permutation(listStringSet(fsa_dict[length])))
        pos_str_dict[length] = shuffled
        print(pos_str_dict[length])
    return pos_str_dict
def _intersect_pda_pda(pda1: PDA, pda2: PDA) -> PDA:
    """Intersects two PDAs; bracket and symbol metadata come from `pda1`."""
    result = PDA()
    result.fst = pynini.intersect(pda1.fst, pda2.fst)
    # Carry over the first operand's symbol table, paren pairs, and the
    # key maps in both directions.
    result.symbol_table = pda1.symbol_table
    result.parens = pda1.parens
    result.open_key_to_close_key = pda1.open_key_to_close_key
    result.close_key_to_open_key = pda1.close_key_to_open_key
    result.token_to_key = pda1.token_to_key
    result.key_to_token = pda1.key_to_token
    result.compile()
    return result
def __call__(self, string: str) -> str:
    """Rewrites `string`, preferring outputs that survive the lexicon filter."""
    try:
        lattice = self._rewrite_s1(string)
    except rewrite.Error:
        return "<composition failure>"
    filtered = pynini.intersect(lattice, self._lexicon)
    # An empty intersection has no start state; fall back to the original
    # lattice in that case, otherwise decode the filtered one.
    source = lattice if filtered.start() == pynini.NO_STATE_ID else filtered
    return self._rewrite_s2(source)
def _language_model_scoring(self, verbal_arr):
    """Compiles `verbal_arr` into a word FST, rescores it against the
    language model, and returns the optimized shortest path."""
    word_fst, self.replacement_dict = self.compiler.fst_stringcompile_words(
        verbal_arr)
    # Prepare the word FST for composition with the LM.
    word_fst.set_output_symbols(self.word_symbols)
    word_fst.optimize()
    word_fst.project(True)
    word_fst.arcsort()
    # Score against the LM and keep only the best path.
    scored = pn.intersect(word_fst, self.lm)
    scored.optimize()
    return pn.shortestpath(scored).optimize()
def _assert_fst_sampled_behavior(
    self, fsts: List[pynini.Fst], token_type: pynini.TokenType, samples: int,
    assert_function: Callable[[pynini.Fst, pynini.Fst], None]) -> None:
  """Checks a behavioral property of `fsts` on randomly drawn inputs.

  Random inputs are sampled from the input projection of the first FST and
  composed with the full FST list; `assert_function` is then applied to each
  (input string FSA, output FST) pair.  Sampling is used because statically
  verifying such properties is hard for non-deterministic FSTs.  For the
  "byte" token type the input projection is first intersected with the
  closure over valid UTF-8 characters, so every sample is a sequence Python
  can hold in a str.  Samples are capped at 100 labels.

  Args:
    fsts: FSTs applied, in order, to each sampled input.
    token_type: The token_type used to derive the FST.
    samples: Number of random inputs to draw.
    assert_function: Callable taking (input string FSA, output FST), run
      inside a `pynini.default_token_type` environment; raises
      AssertionError on failure.
  """
  domain = pynini.project(fsts[0], "input")
  if token_type == "byte":
    # Randgenning raw bytes could yield ill-formed UTF-8 sequences that
    # cannot round-trip through a Python str, so restrict the domain.
    domain = pynini.intersect(domain, utf8.VALID_UTF8_CHAR.star)
  drawn = pynini.randgen(
      domain, npath=samples, max_length=_MAX_SAMPLE_LENGTH)
  with pynini.default_token_type(token_type):
    for ilabels in _olabels_iter(drawn):
      sample_fsa = _label_list_to_string_fsa(ilabels)
      result = rewrite.ComposeFsts([sample_fsa] + fsts)
      assert_function(sample_fsa, result)
#------------------ #lt0 - bb lt0_accept = (not_b.star + b + not_b.star + b + not_b.star).optimize() lt0_accept.write("lt0_accept.fsa") #lt1 - b^4 OR a^4 bbbb = (not_ab.star + b + b + b + b + not_ab.star).optimize() aaaa = (not_ab.star + a + a + a + a + not_ab.star).optimize() lt1_accept = pynini.union(aaaa, bbbb) lt1_accept.write("lt1_accept.fsa") #lt2 - b^4 AND a^4 aaaa_2 = (not_a.star + a + a + a + a + not_a.star).optimize() bbbb_2 = (not_b.star + b + b + b + b + not_b.star).optimize() lt2_accept = (pynini.intersect(aaaa_2, bbbb_2)).optimize() lt2_accept.write("lt2_accept.fsa") #lt3 - if b^8 then a^8 - NEEDS WORK #FIRST - strings with b^8 and a^8 b_8 = (not_b.star + b + b + b + b + b + b + b + b + not_b.star).optimize() a_8 = (not_a.star + a + a + a + a + a + a + a + a + not_a.star).optimize() b8_a8 = (pynini.intersect(b_8, a_8)).optimize() #SECOND - strings with b^7 and a^7 b_7 = (not_b.star + b + b + b + b + b + b + b + not_b.star).optimize() a_7 = (not_a.star + a + a + a + a + a + a + a
def get_pos_string(fsa, min_len, max_len):
    """Maps each length in [min_len, max_len] to the FSA accepting exactly
    the strings of that length in `fsa`'s language."""
    return {
        length: pynini.intersect(fsa, pynini.closure(sigma, length, length))
        for length in range(min_len, max_len + 1)
    }
# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """File access functions used in this package.""" import errno import os import pathlib import pynini from rules_python.python.runfiles import runfiles EMPTY: pynini.Fst = pynini.intersect(pynini.accep("a"), pynini.accep("b")).optimize() EPSILON: pynini.Fst = pynini.accep("").optimize() def AsResourcePath(filename: os.PathLike) -> os.PathLike: filename = os.fspath(filename) return runfiles.Create().Rlocation(filename) def IsFileExist(filename: os.PathLike) -> bool: """Checks if a resource file exists.""" try: filename = AsResourcePath(filename) if os.path.isfile(filename): return True except IOError as ex:
def decode(self, t9_input: pynini.FstLike) -> pynini.Fst:
    """Decodes a T9 input sequence into the lattice of in-lexicon words."""
    candidates = rewrite.rewrite_lattice(t9_input, self._decoder)
    # Keep only decodings that are actual lexicon entries.
    lexical = pynini.intersect(candidates, self._lexicon)
    return lexical
sp[1] = sigma4Star - lg_containing_ssq(b, 4)  # SP4 , forbidden bbbb
sp[2] = sigma4Star - lg_containing_ssq(b, 8)  # SP8 , forbidden bbbbbbbb

###############
# LT Examples #
###############
lt = dict()
# LT2 , at least one bb
lt[0] = lg_containing_str(b, 2)
# LT4 , at least one bbbb or at least one aaaa
# FIX(review): `+` concatenates the two languages ("…bbbb… then …aaaa…");
# the comment calls for their union, so use `|` instead.
lt[1] = lg_containing_str(b, 4) | lg_containing_str(a, 4)
# LT4 , at least one bbbb and at least one aaaa
lt[2] = pynini.intersect(lg_containing_str(b, 4), lg_containing_str(a, 4))
# LT8 , if b^8 then a^8 (~~~ not b^8 or a^8)
# FIX(review): implication is (not b^8) OR (a^8); `+` concatenated the
# complement with the a^8 language rather than taking their union.
lt[3] = (sigma4Star - lg_containing_str(b, 8)) | lg_containing_str(a, 8)

###############
# PT Examples #
###############
pt = dict()
# PT2 , at least one bb
pt[0] = lg_containing_ssq(b, 2)
# PT4 , at least one bbbb or at least one aaaa
# FIX(review): union, not concatenation, matches the stated "or".
pt[1] = lg_containing_ssq(b, 4) | lg_containing_ssq(a, 4)
sp[2] = sigma4Star - lg_containing_ssq(b, 8) # SP8 , forbidden bbbbbbbb ############### # LT Examples # ############### lt = dict() # LT2 , at least one bb lt[0] = lg_containing_str(b, 2) # LT4 , at least one bbbb or at least one aaaa lt[1] = pynini.union(lg_containing_str(b, 4), lg_containing_str(a, 4)) # lt[1] = lg_containing_str(b,4) + lg_containing_str(a,4) # LT4 , at least one bbbb and at least one aaaa lt[2] = pynini.intersect(lg_containing_str(b, 4), lg_containing_str(a, 4)) # LT8 , if b^8 then a^8 (~~~ not b^8 or a^8) lt[3] = (sigma4Star - lg_containing_str(b, 8)) | lg_containing_str(a, 8) # aa and ab substrings lt[4] = pynini.intersect(lg_containing_str(a, 2), lg_containing_str(a + b, 1)) # aa and ab substrings (using sigma = {a,b}) lt[5] = pynini.intersect(lg_with_str(sigma2Star, a, 2), lg_with_str(sigma2Star, a + b, 1)) ############### # PT Examples # ###############
def testJoin(self):
    """The join of "a" with " " must accept "a", "a a", ... up to nine terms."""
    joined = pynutil.join("a", " ")
    for count in range(1, 10):
        query = " ".join(["a"] * count)
        lattice = pynini.intersect(joined, query)
        # A missing start state would mean the query was rejected.
        self.assertNotEqual(lattice.start(), pynini.NO_STATE_ID)
def all_single_byte_substrings(self, fsa: pynini.Fst) -> pynini.Fst:
    """Returns the optimized single-byte substrings of `fsa`'s language."""
    substrings = self.all_substrings(fsa)
    # Restrict to single-byte strings before optimizing.
    return pynini.intersect(substrings, byte.BYTE).optimize()