示例#1
0
    def matches(self,
                istring: pynini.FstLike,
                ostring: pynini.FstLike,
                input_token_type: Optional[pynini.TokenType] = None,
                output_token_type: Optional[pynini.TokenType] = None) -> bool:
        """Checks whether the rule cascade relates `istring` to `ostring`.

        The rewrite lattice for `istring` is intersected with `ostring`; the
        pair is accepted when the resulting FST has a start state.

        Args:
          istring: Input string or FST.
          ostring: Output string or FST.
          input_token_type: Optional input token type, or symbol table.
          output_token_type: Optional output token type, or symbol table.

        Returns:
          True if the intersection of the rewrite lattice with `ostring` has a
          start state, False otherwise.
        """

        def _restrict_to_output(rewrites):
            # Single place for the intersection so both branches stay in sync.
            return pynini.intersect(rewrites,
                                    ostring,
                                    compose_filter="sequence")

        candidates = self._rewrite_lattice(istring, input_token_type)
        # TODO(kbg): Consider using `contextlib.nullcontext` here instead.
        if output_token_type is not None:
            # `ostring` is interpreted under the requested token type.
            with pynini.default_token_type(output_token_type):
                matched = _restrict_to_output(candidates)
        else:
            matched = _restrict_to_output(candidates)
        return matched.start() != pynini.NO_STATE_ID
示例#2
0
def gen(fsa, accept):
    """Randomly samples paths from the intersection of `fsa` and `accept`.

    The earlier version drew ten samples of growing size in a loop but
    overwrote the result on every pass, so only the last (largest) draw was
    ever returned. That single draw is now made directly and the intersection
    is built once instead of once per pass.

    Args:
      fsa: First operand handed to `pynini.intersect`.
      accept: Second operand handed to `pynini.intersect`.

    Returns:
      The result of `list_string_set` applied to the sampled FST.
    """
    loop = 10
    n = 100
    # Size of the final draw in the original schedule n, 1.1n, ..., 1.9n.
    num = int(n + n * (loop - 1) * 0.1)
    restricted = pynini.intersect(fsa, accept)
    rand = pynini.randgen(restricted,
                          npath=num,
                          seed=0,
                          select="uniform",
                          max_length=100,
                          weighted=False)
    return list_string_set(rand)
示例#3
0
文件: util.py 项目: danieldeutsch/gcd
def _intersect_fsa_fsa(fsa1: FSA, fsa2: FSA) -> FSA:
    """Builds a new FSA whose FST is the intersection of the two inputs' FSTs.

    The symbol table and the token/key maps are taken from `fsa1`; `fsa2`
    contributes only its `fst`. The new FSA is compiled before it is returned.
    """
    result = FSA()
    result.fst = pynini.intersect(fsa1.fst, fsa2.fst)
    # Vocabulary metadata comes from the first operand only.
    for attr in ("symbol_table", "token_to_key", "key_to_token"):
        setattr(result, attr, getattr(fsa1, attr))
    result.compile()
    return result
 def get_best_expansion(self, expansions):
     """Picks the single shortest path of `expansions` intersected with `self.LM`.

     Progress messages are printed before the intersection and before the
     optimization step.

     Args:
       expansions: FST intersected with `self.LM`.

     Returns:
       The optimized result of `pn.shortestpath(..., nshortest=1)`.
     """
     print("combining expansions and LM ...")
     scored = pn.intersect(expansions, self.LM)
     print("optimizing intersection ...")
     scored.optimize()
     best_path = pn.shortestpath(scored, nshortest=1)
     return best_path.optimize()
示例#5
0
def getPosString(fsa, min_len, max_len):
    """Collects, per length, a shuffled list of the strings of `fsa`.

    For every length from `min_len` to `max_len` inclusive, `fsa` is
    intersected with exactly that many repetitions of `sigma`, the resulting
    string set is listed with `listStringSet`, randomly permuted, and printed.

    Args:
      fsa: The acceptor to draw strings from.
      min_len: Smallest length considered.
      max_len: Largest length considered (inclusive).

    Returns:
      A dict mapping each length to its permuted list of strings; empty when
      `min_len > max_len`.
    """
    shuffled_by_len = {}
    for length in range(min_len, max_len + 1):
        exact_len_fsa = pynini.intersect(
            fsa, pynini.closure(sigma, length, length))
        strings = listStringSet(exact_len_fsa)
        shuffled_by_len[length] = list(np.random.permutation(strings))
        print(shuffled_by_len[length])
    return shuffled_by_len
示例#6
0
文件: util.py 项目: danieldeutsch/gcd
def _intersect_pda_pda(pda1: PDA, pda2: PDA) -> PDA:
    """Builds a new PDA whose FST is the intersection of the two inputs' FSTs.

    The symbol table, parentheses, and the open/close and token/key maps are
    all taken from `pda1`; `pda2` contributes only its `fst`. The new PDA is
    compiled before it is returned.
    """
    inherited_attrs = (
        "symbol_table",
        "parens",
        "open_key_to_close_key",
        "close_key_to_open_key",
        "token_to_key",
        "key_to_token",
    )
    result = PDA()
    result.fst = pynini.intersect(pda1.fst, pda2.fst)
    # Everything except the FST itself is carried over from the first operand.
    for attr in inherited_attrs:
        setattr(result, attr, getattr(pda1, attr))
    result.compile()
    return result
 def __call__(self, string: str) -> str:
     """Rewrites `string`, preferring candidates that are in the lexicon.

     Args:
       string: The input passed to `self._rewrite_s1`.

     Returns:
       "<composition failure>" if the first rewrite raises `rewrite.Error`;
       otherwise the result of `self._rewrite_s2` applied to the
       lexicon-filtered lattice, or to the unfiltered lattice when the
       filtered one has no start state.
     """
     try:
         candidates = self._rewrite_s1(string)
     except rewrite.Error:
         return "<composition failure>"
     in_lexicon = pynini.intersect(candidates, self._lexicon)
     # Fall back to the unfiltered lattice when nothing survives the lexicon.
     has_lexicon_match = in_lexicon.start() != pynini.NO_STATE_ID
     return self._rewrite_s2(in_lexicon if has_lexicon_match else candidates)
示例#8
0
    def _language_model_scoring(self, verbal_arr):
        """Returns the shortest path of the word lattice intersected with `self.lm`.

        Side effect: `self.replacement_dict` is overwritten with the second
        value returned by `self.compiler.fst_stringcompile_words`.

        Args:
          verbal_arr: Passed unchanged to `fst_stringcompile_words`.

        Returns:
          The optimized result of `pn.shortestpath` on the intersection.
        """
        compiled = self.compiler.fst_stringcompile_words(verbal_arr)
        candidates, self.replacement_dict = compiled

        # Prepare the lattice in place before intersecting it with the LM.
        candidates.set_output_symbols(self.word_symbols)
        candidates.optimize()
        # NOTE(review): `True` presumably selects the output side; confirm
        # against the pynini version in use.
        candidates.project(True)
        candidates.arcsort()

        scored = pn.intersect(candidates, self.lm)
        scored.optimize()
        return pn.shortestpath(scored).optimize()
示例#9
0
    def _assert_fst_sampled_behavior(
            self, fsts: List[pynini.Fst], token_type: pynini.TokenType,
            samples: int, assert_function: Callable[[pynini.Fst, pynini.Fst],
                                                    None]) -> None:
        """Asserts a property of the FST cascade on randomly sampled inputs.

        Inputs are sampled from the input projection of the first FST, and
        each sample is composed with all of `fsts`; the property is then
        checked on the composed result. Sampling stands in for a static proof,
        which is hard to obtain for non-deterministic FSTs. When `token_type`
        is "byte", the sampled language is first intersected with the closure
        of valid UTF-8 characters so every sample is a well-formed UTF-8
        string. Sample length is capped at `_MAX_SAMPLE_LENGTH` labels.

        Args:
          fsts: FSTs applied, in order, to each sample.
          token_type: The token type used to derive the FSTs.
          samples: Number of input samples to draw.
          assert_function: Called with the sample's string FSA and the
              composed output FST, inside a `pynini.default_token_type`
              context. It raises AssertionError when the property fails.
        """
        sample_source = pynini.project(fsts[0], "input")
        if token_type == "byte":
            # NOTE: Sampling straight from the byte machine can yield label
            # sequences that are not well-formed UTF-8 and so cannot become
            # a Python str; restrict to valid UTF-8 first.
            sample_source = pynini.intersect(sample_source,
                                             utf8.VALID_UTF8_CHAR.star)
        sampled_paths = pynini.randgen(sample_source,
                                       npath=samples,
                                       max_length=_MAX_SAMPLE_LENGTH)
        with pynini.default_token_type(token_type):
            for labels in _olabels_iter(sampled_paths):
                sample_fsa = _label_list_to_string_fsa(labels)
                composed = rewrite.ComposeFsts([sample_fsa] + fsts)
                assert_function(sample_fsa, composed)
示例#10
0
#------------------

# lt0 - bb
# Built as not_b* b not_b* b not_b*, then optimized and written to disk.
# NOTE(review): `b` and `not_b` are defined outside this excerpt; presumably
# `not_b` matches any single symbol other than b -- confirm.
lt0_accept = (not_b.star + b + not_b.star + b + not_b.star).optimize()
lt0_accept.write("lt0_accept.fsa")

# lt1 - b^4 OR a^4
# Each operand is four copies of one symbol with not_ab* on both sides; the
# two are combined by union. Unlike the other acceptors here, the union is
# written out without a final optimize() call.
bbbb = (not_ab.star + b + b + b + b + not_ab.star).optimize()
aaaa = (not_ab.star + a + a + a + a + not_ab.star).optimize()
lt1_accept = pynini.union(aaaa, bbbb)
lt1_accept.write("lt1_accept.fsa")

# lt2 - b^4 AND a^4
# Same shape as lt1, but each operand is padded with not_a* / not_b*
# respectively and the two are combined by intersection.
aaaa_2 = (not_a.star + a + a + a + a + not_a.star).optimize()
bbbb_2 = (not_b.star + b + b + b + b + not_b.star).optimize()
lt2_accept = (pynini.intersect(aaaa_2, bbbb_2)).optimize()
lt2_accept.write("lt2_accept.fsa")

# lt3 - if b^8 then a^8 - NEEDS WORK
# FIRST - strings with b^8 and a^8: intersection of the two eight-symbol
# acceptors below.

b_8 = (not_b.star + b + b + b + b + b + b + b + b
            + not_b.star).optimize()
a_8 = (not_a.star + a + a + a + a + a + a + a + a
            + not_a.star).optimize()
b8_a8 = (pynini.intersect(b_8, a_8)).optimize()

# SECOND - strings with b^7 and a^7: seven-symbol acceptors, same construction.
b_7 = (not_b.star + b + b + b + b + b + b + b 
            + not_b.star).optimize()
a_7 = (not_a.star + a + a + a + a + a + a + a
示例#11
0
def get_pos_string(fsa, min_len, max_len):
    """Maps each length to `fsa` restricted to strings of that length.

    For every length from `min_len` to `max_len` inclusive, `fsa` is
    intersected with exactly that many repetitions of `sigma`.

    Args:
      fsa: The acceptor to restrict.
      min_len: Smallest length considered.
      max_len: Largest length considered (inclusive).

    Returns:
      A dict from length to the intersected FST; empty when
      `min_len > max_len`.
    """
    return {
        length: pynini.intersect(fsa, pynini.closure(sigma, length, length))
        for length in range(min_len, max_len + 1)
    }
示例#12
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""File access functions used in this package."""

import errno
import os
import pathlib

import pynini
from rules_python.python.runfiles import runfiles

# Intersection of the acceptors for "a" and "b", optimized. The two acceptors
# share no string, so this serves as the empty-language FST.
EMPTY: pynini.Fst = pynini.intersect(pynini.accep("a"),
                                     pynini.accep("b")).optimize()
# Optimized acceptor compiled from the empty string.
EPSILON: pynini.Fst = pynini.accep("").optimize()


def AsResourcePath(filename: os.PathLike) -> os.PathLike:
    """Resolves `filename` through a freshly created runfiles object.

    Args:
      filename: Path-like object; converted with `os.fspath` before lookup.

    Returns:
      Whatever `Rlocation` returns for the converted path.
    """
    path_str = os.fspath(filename)
    resolver = runfiles.Create()
    return resolver.Rlocation(path_str)


def IsFileExist(filename: os.PathLike) -> bool:
    """Checks if a resource file exists."""
    try:
        filename = AsResourcePath(filename)
        if os.path.isfile(filename):
            return True
    except IOError as ex:
示例#13
0
文件: t9.py 项目: yzhang123/pynini
 def decode(self, t9_input: pynini.FstLike) -> pynini.Fst:
     """Decodes `t9_input` and keeps only the candidates found in the lexicon.

     Args:
       t9_input: Input string or FST passed to the decoder rewrite.

     Returns:
       The rewrite lattice produced with `self._decoder`, intersected with
       `self._lexicon`.
     """
     return pynini.intersect(
         rewrite.rewrite_lattice(t9_input, self._decoder), self._lexicon)
示例#14
0
sp[1] = sigma4Star - lg_containing_ssq(b, 4)  # SP4 , forbidden bbbb
sp[2] = sigma4Star - lg_containing_ssq(b, 8)  # SP8 , forbidden bbbbbbbb

###############
# LT Examples #
###############

lt = dict()
# LT2 , at least one bb
lt[0] = lg_containing_str(b, 2)

# LT4 , at least one bbbb or at least one aaaa
# "or" is language union. `+` on FSTs is concatenation, which would instead
# require a bbbb-string followed by an aaaa-string.
lt[1] = pynini.union(lg_containing_str(b, 4), lg_containing_str(a, 4))

# LT4 , at least one bbbb and at least one aaaa
lt[2] = pynini.intersect(lg_containing_str(b, 4), lg_containing_str(a, 4))

# LT8 , if b^8 then a^8 (~~~ not b^8 or a^8)
# The implication is encoded as (complement of "contains b^8") UNION
# "contains a^8"; again union, not concatenation.
lt[3] = pynini.union(sigma4Star - lg_containing_str(b, 8),
                     lg_containing_str(a, 8))

###############
# PT Examples #
###############

pt = dict()
# PT2 , at least one bb
pt[0] = lg_containing_ssq(b, 2)

# PT4 , at least one bbbb or at least one aaaa
# Union for "or", matching the LT4 example above.
pt[1] = pynini.union(lg_containing_ssq(b, 4), lg_containing_ssq(a, 4))
示例#15
0
sp[2] = sigma4Star - lg_containing_ssq(b, 8)  # SP8 , forbidden bbbbbbbb

###############
# LT Examples #
###############

lt = {
    # LT2: at least one bb.
    0: lg_containing_str(b, 2),
    # LT4: at least one bbbb or at least one aaaa (union of the two).
    1: pynini.union(lg_containing_str(b, 4), lg_containing_str(a, 4)),
    # LT4: at least one bbbb and at least one aaaa (intersection).
    2: pynini.intersect(lg_containing_str(b, 4), lg_containing_str(a, 4)),
    # LT8: if b^8 then a^8, written as (not b^8) or a^8.
    3: (sigma4Star - lg_containing_str(b, 8)) | lg_containing_str(a, 8),
    # Both an aa substring and an ab substring.
    4: pynini.intersect(lg_containing_str(a, 2),
                        lg_containing_str(a + b, 1)),
    # Same as entry 4, built over sigma2Star (sigma = {a,b}).
    5: pynini.intersect(lg_with_str(sigma2Star, a, 2),
                        lg_with_str(sigma2Star, a + b, 1)),
}

###############
# PT Examples #
###############
示例#16
0
 def testJoin(self):
     """Checks that join("a", " ") accepts 1 through 9 space-separated "a"s."""
     repeated_a = pynutil.join("a", " ")
     for count in range(1, 10):
         # Joining the characters of "aa...a" yields "a a ... a".
         expected = " ".join("a" * count)
         matched = pynini.intersect(repeated_a, expected)
         # A non-empty intersection has a start state.
         self.assertNotEqual(matched.start(), pynini.NO_STATE_ID)
示例#17
0
 def all_single_byte_substrings(self, fsa: pynini.Fst) -> pynini.Fst:
     """Returns the substrings of `fsa` that are also accepted by `byte.BYTE`.

     Args:
       fsa: The FST whose substrings are computed via `self.all_substrings`.

     Returns:
       The optimized intersection of those substrings with `byte.BYTE`.
     """
     substrings = self.all_substrings(fsa)
     single_bytes = pynini.intersect(substrings, byte.BYTE)
     return single_bytes.optimize()