def test_fst_to_bytes(): fst = VectorFst() # States s1 = fst.add_state() s2 = fst.add_state() fst.set_start(s1) fst.set_final(s2) bytes = fst.to_bytes() with NamedTemporaryFile() as f: Path(f.name).write_bytes(bytes) fst_read = VectorFst.read(f.name) assert fst == fst_read
def test_fst_read_write(): fst = VectorFst() # States s1 = fst.add_state() s2 = fst.add_state() fst.set_start(s1) fst.set_final(s2) tr_1 = Tr(3, 5, 10.0, s2) tr_2 = Tr(5, 7, 18.0, s2) fst.add_tr(s1, tr_1) fst.add_tr(s1, tr_2) fst.write("/tmp/test.fst") read_fst = VectorFst.read("/tmp/test.fst") assert fst == read_fst
def test_fst_read_write_with_symt(): fst = VectorFst() # States s1 = fst.add_state() s2 = fst.add_state() fst.set_start(s1) fst.set_final(s2) tr_1 = Tr(3, 5, 10.0, s2) tr_2 = Tr(5, 7, 18.0, s2) fst.add_tr(s1, tr_1) fst.add_tr(s1, tr_2) input_symt = SymbolTable() input_symt.add_symbol("a") input_symt.add_symbol("b") input_symt.add_symbol("c") fst.set_input_symbols(input_symt) output_symt = SymbolTable() fst.set_output_symbols(output_symt) fst.write("/tmp/test.fst") read_fst = VectorFst.read("/tmp/test.fst") assert read_fst.input_symbols().num_symbols() == 4 assert read_fst.input_symbols().find("a") == 1 assert read_fst.input_symbols().find("b") == 2 assert read_fst.input_symbols().find("c") == 3 assert read_fst.output_symbols().num_symbols() == 1 assert fst == read_fst