def test_fst_relabel_tables():
    """Relabel arc labels from old symbol tables to new ones.

    Covers three variants: new tables attached, new tables not attached, and
    old tables omitted (taken from the tables already set on the FST).
    """
    # Source FST: one arc labelled 1:2 from the start state to a final state.
    source = VectorFst()
    start = source.add_state()
    end = source.add_state()
    source.add_tr(start, Tr(1, 2, weight_one(), end))
    source.set_start(start)
    source.set_final(end)

    # The "after" tables list the same symbols as the "before" tables, reversed.
    isymt_before = SymbolTable.from_symbols(["a", "b"])
    isymt_after = SymbolTable.from_symbols(["b", "a"])
    osymt_before = SymbolTable.from_symbols(["aa", "bb"])
    osymt_after = SymbolTable.from_symbols(["bb", "aa"])

    # Expected result: the arc labels follow their symbols, giving 2:1.
    expected = VectorFst()
    ref_start = expected.add_state()
    ref_end = expected.add_state()
    expected.add_tr(ref_start, Tr(2, 1, weight_one(), ref_end))
    expected.set_start(ref_start)
    expected.set_final(ref_end)

    # Variant 1: explicit old tables; the new tables get attached.
    attached = source.copy()
    attached.relabel_tables(
        old_isymbols=isymt_before,
        new_isymbols=isymt_after,
        attach_new_isymbols=True,
        old_osymbols=osymt_before,
        new_osymbols=osymt_after,
        attach_new_osymbols=True,
    )
    assert attached == expected
    assert attached.input_symbols() == isymt_after
    assert attached.output_symbols() == osymt_after

    # Variant 2: explicit old tables; nothing gets attached.
    detached = source.copy()
    detached.relabel_tables(
        old_isymbols=isymt_before,
        new_isymbols=isymt_after,
        attach_new_isymbols=False,
        old_osymbols=osymt_before,
        new_osymbols=osymt_after,
        attach_new_osymbols=False,
    )
    assert detached == expected
    assert detached.input_symbols() is None
    assert detached.output_symbols() is None

    # Variant 3: old tables come from the FST itself.
    implicit = source.copy()
    implicit.set_input_symbols(isymt_before)
    implicit.set_output_symbols(osymt_before)
    implicit.relabel_tables(
        new_isymbols=isymt_after,
        new_osymbols=osymt_after,
    )
    assert implicit == expected
    assert implicit.input_symbols() == isymt_after
    assert implicit.output_symbols() == osymt_after
def test_symt_copy_add():
    """A copy of an FST's symbol table can be extended independently of the source table."""
    fst = VectorFst()
    table = SymbolTable.from_symbols(["a", "b"])
    for attach in (fst.set_input_symbols, fst.set_output_symbols):
        attach(table)

    extended = fst.input_symbols().copy()
    extended.add_symbol("c")

    # If the copy shared storage with `table`, both counts would grow together.
    assert extended.num_symbols() == table.num_symbols() + 1
def test_eq_table():
    """Two tables filled with the same symbols in the same order compare equal."""
    left = SymbolTable()
    for symbol in ("a", "b"):
        left.add_symbol(symbol)

    right = SymbolTable()
    for symbol in ("a", "b"):
        right.add_symbol(symbol)

    assert left == right
def test_transducer():
    """Build a transducer from an input and an output string, then draw it.

    The DOT output goes to a temporary directory rather than the current
    working directory, so running the test leaves no stray files behind.
    """
    import os
    import tempfile

    symt = SymbolTable()
    for word in ("hello", "world", "coucou", "monde"):
        symt.add_symbol(word)

    f = transducer("hello world", "coucou monde", symt, symt)
    d = DrawingConfig()

    with tempfile.TemporaryDirectory() as tmp_dir:
        dot_path = os.path.join(tmp_dir, "transducer.dot")
        f.draw(dot_path, None, None, d)
        # The test previously asserted nothing; at least check that drawing produced a file.
        assert os.path.exists(dot_path)
def test_fst_with_symt_mut_fail():
    """Adding a symbol to the table returned by ``input_symbols()`` raises.

    The error message asserted below reports that no mutable reference to the
    table could be obtained.
    """
    fst = VectorFst()
    start = fst.add_state()
    final = fst.add_state()
    fst.set_start(start)
    fst.set_final(final)

    isymt = SymbolTable()
    for symbol in ["a", "b", "c"]:
        isymt.add_symbol(symbol)
    fst.set_input_symbols(isymt)
    fst.set_output_symbols(SymbolTable())

    expected_message = (
        '`add_symbol` failed: "Could not get a mutable reference to the symbol table"'
    )
    with pytest.raises(Exception) as excinfo:
        fst.input_symbols().add_symbol("d")
    # Checked outside the `with` so the assertion actually runs after the raise.
    assert str(excinfo.value) == expected_message
def test_fst_symt():
    """Symbol tables attached to an FST are returned with their contents intact."""
    fst = VectorFst()
    first = fst.add_state()
    second = fst.add_state()
    fst.set_start(first)
    fst.set_final(second, 1.0)

    arcs = [
        (first, Tr(1, 0, 10.0, second)),
        (second, Tr(2, 0, 1.0, first)),
        (second, Tr(3, 0, 1.0, second)),
    ]
    for origin, tr in arcs:
        fst.add_tr(origin, tr)

    # Input side: three user symbols plus the epsilon entry at id 0.
    isymt = SymbolTable()
    for label in ("a", "b", "c"):
        isymt.add_symbol(label)
    fst.set_input_symbols(isymt)

    got_in = fst.input_symbols()
    assert isymt == got_in
    assert got_in.num_symbols() == 4
    for expected_id, label in enumerate(("a", "b", "c"), start=1):
        assert got_in.find(label) == expected_id

    # Output side: a fresh table that holds only epsilon.
    osymt = SymbolTable()
    fst.set_output_symbols(osymt)

    got_out = fst.output_symbols()
    assert osymt == got_out
    assert got_out.num_symbols() == 1
def test_symt_add_twice_symbol():
    """Adding the same symbol twice keeps a single entry with its first id."""
    table = SymbolTable()
    for _ in range(2):
        table.add_symbol("a")

    # Two entries in total: epsilon plus "a".
    assert table.num_symbols() == 2
    assert table.find("a") == 1
def test_string_paths_iterator():
    """Walk the string paths of a two-arc FST, checking weight and labels of each."""
    fst = VectorFst()
    src = fst.add_state()
    dst = fst.add_state()
    fst.set_start(src)
    fst.set_final(dst, 2.0)
    fst.add_tr(src, Tr(1, 2, 2.0, dst))
    fst.add_tr(src, Tr(2, 3, 3.0, dst))

    table = SymbolTable()
    for sym in ("a", "b", "c"):
        table.add_symbol(sym)
    fst.set_input_symbols(table)
    fst.set_output_symbols(table)

    paths = fst.string_paths()
    # Each expected weight is the arc weight added to the final weight 2.0.
    expected = [
        (4.0, "a", "b"),
        (5.0, "b", "c"),
    ]
    for weight, istring, ostring in expected:
        assert not paths.done()
        path = next(paths)
        assert path.weight() == weight
        assert path.istring() == istring
        assert path.ostring() == ostring
    assert paths.done()
def test_acceptor():
    """`acceptor` turns "hello world" into a two-arc linear chain of identity labels."""
    symt = SymbolTable()
    for word in ("hello", "world"):
        symt.add_symbol(word)
    f = acceptor("hello world", symt)

    # Expected: states 0 -> 1 -> 2, arcs labelled 1:1 then 2:2, default weight.
    expected = VectorFst()
    states = [expected.add_state() for _ in range(3)]
    expected.set_start(states[0])
    expected.set_final(states[-1])
    for label, (src, dst) in enumerate(zip(states, states[1:]), start=1):
        expected.add_tr(src, Tr(label, label, None, dst))

    assert f == expected
def test_fst_read_write_with_symt():
    """Round-trip an FST with attached symbol tables through a file on disk.

    The file is written into a fresh temporary directory instead of the
    hard-coded ``/tmp/test.fst``. That path does not exist on Windows and can
    collide between concurrent or parallel test runs.
    """
    import os
    import tempfile

    fst = VectorFst()
    # States
    s1 = fst.add_state()
    s2 = fst.add_state()
    fst.set_start(s1)
    fst.set_final(s2)

    tr_1 = Tr(3, 5, 10.0, s2)
    tr_2 = Tr(5, 7, 18.0, s2)
    fst.add_tr(s1, tr_1)
    fst.add_tr(s1, tr_2)

    input_symt = SymbolTable()
    input_symt.add_symbol("a")
    input_symt.add_symbol("b")
    input_symt.add_symbol("c")
    fst.set_input_symbols(input_symt)

    output_symt = SymbolTable()
    fst.set_output_symbols(output_symt)

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test.fst")
        fst.write(path)
        read_fst = VectorFst.read(path)

        # The symbol tables must survive serialization along with the arcs.
        assert read_fst.input_symbols().num_symbols() == 4
        assert read_fst.input_symbols().find("a") == 1
        assert read_fst.input_symbols().find("b") == 2
        assert read_fst.input_symbols().find("c") == 3
        assert read_fst.output_symbols().num_symbols() == 1
        assert fst == read_fst
def test_symt_iterator():
    """Iterating a table yields (id, symbol) pairs in id order, starting with epsilon."""
    symt = SymbolTable()
    for sym in ("a", "b"):
        symt.add_symbol(sym)

    expected = [(0, "<eps>"), (1, "a"), (2, "b")]
    observed = list(symt)
    assert observed == expected
def test_symt():
    """Look up symbols by string and by id, including the implicit epsilon at id 0."""
    symt = SymbolTable()
    symt.add_symbol("a")
    symt.add_symbol("b")
    assert symt.num_symbols() == 3

    known = [(0, EPS_SYMBOL), (1, "a"), (2, "b")]

    # String -> id lookups and string membership.
    for idx, sym in known:
        assert symt.find(sym) == idx
    for _, sym in known:
        assert symt.member(sym) is True
    assert symt.member("c") is False

    # Id -> string lookups and id membership.
    for idx, sym in known:
        assert symt.find(idx) == sym
    for idx, _ in known:
        assert symt.member(idx) is True
    assert symt.member(3) is False
def test_add_table():
    """`add_table` appends new symbols and keeps the ids of symbols already present."""
    base = SymbolTable()
    for sym in ("a", "b"):
        base.add_symbol(sym)

    extra = SymbolTable()
    for sym in ("c", "b"):
        extra.add_symbol(sym)

    base.add_table(extra)

    # "b" was already present, so only "c" is new: eps, a, b, c.
    assert base.num_symbols() == 4
    expected_ids = {EPS_SYMBOL: 0, "a": 1, "b": 2, "c": 3}
    for sym, idx in expected_ids.items():
        assert base.find(sym) == idx