def __init__(self):
    """Populate a fixed five-node lattice state with preset tokens.

    ``self._sorted_nodes`` receives five ``Lattice.Node`` objects with fixed
    times (node 3 gets ``None``). ``self._tokens`` receives one list of
    ``LatticeDecoder.Token`` objects per node, and each token gets a preset
    ``total_logprob`` and ``recombination_hash``. Every node that has tokens
    gets ``best_logprob`` equal to its highest token log probability. The
    last node has no tokens, and no ``best_logprob`` is assigned to it.
    """
    # NOTE(review): presumably mimics the attributes a LatticeDecoder keeps
    # while decoding; confirm against the code under test.
    node_times = [0.0, 1.0, 1.0, None, 3.0]
    # (total_logprob, recombination_hash) of every token, grouped per node.
    token_specs = [
        [(-10.0, 1)],
        [(-20.0, 1)],
        [(-30.0, 1), (-50.0, 2), (-70.0, 3)],
        [(-100.0, 1)],
        [],
    ]

    self._sorted_nodes = [Lattice.Node(node_id)
                          for node_id in range(len(node_times))]
    for node, node_time in zip(self._sorted_nodes, node_times):
        node.time = node_time

    self._tokens = []
    for node, specs in zip(self._sorted_nodes, token_specs):
        node_tokens = []
        for logprob, recombination_hash in specs:
            token = LatticeDecoder.Token()
            token.total_logprob = logprob
            token.recombination_hash = recombination_hash
            node_tokens.append(token)
        self._tokens.append(node_tokens)
        if node_tokens:
            node.best_logprob = max(token.total_logprob
                                    for token in node_tokens)
def test_read_slf_node(self):
    """Parse SLF node fields given with both short and long field names."""
    lattice = SLFLattice(None)
    lattice.nodes = [Lattice.Node(node_id) for node_id in range(5)]

    fields_per_node = (
        [],
        ['t=1.0'],
        ['time=2.1'],
        ['t=3.0', 'WORD=wo rd'],
        ['time=4.1', 'W=word'],
    )
    for node_id, fields in enumerate(fields_per_node):
        lattice._read_slf_node(node_id, fields)

    # (node id, expected time, expected word or None when not checked)
    expectations = (
        (1, 1.0, None),
        (2, 2.1, None),
        (3, 3.0, 'wo rd'),
        (4, 4.1, 'word'),
    )
    for node_id, expected_time, expected_word in expectations:
        node = lattice.nodes[node_id]
        self.assertEqual(node.time, expected_time)
        if expected_word is not None:
            self.assertEqual(node.word, expected_word)
def test_read_slf_link(self):
    """Parse SLF link fields and check that links are wired to their nodes.

    Both the long (START/END/WORD/acoustic/language) and the short
    (S/E/W/a/l) field names are exercised. The test checks link endpoints,
    words and scores, the number of in/out links of every node, and the
    order in which links appear in the nodes' link lists.
    """
    lattice = SLFLattice(None)
    lattice.nodes = [Lattice.Node(node_id) for node_id in range(4)]
    lattice.links = []
    for node_id, time_field in enumerate(['t=0.0', 't=1.0', 't=2.0', 't=3.0']):
        lattice._read_slf_node(node_id, [time_field])

    link_fields = [
        ['START=0', 'END=1'],
        ['S=1', 'E=2', 'WORD=wo rd', 'acoustic=-0.1', 'language=-0.2'],
        ['S=2', 'E=3', 'W=word', 'a=-0.3', 'l=-0.4'],
        ['S=1', 'E=3', 'a=-0.5', 'l=-0.6'],
    ]
    for link_id, fields in enumerate(link_fields):
        lattice._read_slf_link(link_id, fields)

    nodes = lattice.nodes
    links = lattice.links

    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    expected_endpoints = [(0, 1), (1, 2), (2, 3), (1, 3)]
    for link_id, (start_id, end_id) in enumerate(expected_endpoints):
        self.assertIs(links[link_id].start_node, nodes[start_id])
        self.assertIs(links[link_id].end_node, nodes[end_id])

    self.assertEqual(links[1].word, 'wo rd')
    self.assertEqual(links[2].word, 'word')

    # (link id, acoustic log probability, language model log probability)
    expected_scores = [(1, -0.1, -0.2), (2, -0.3, -0.4), (3, -0.5, -0.6)]
    for link_id, ac_logprob, lm_logprob in expected_scores:
        self.assertEqual(links[link_id].ac_logprob, ac_logprob)
        self.assertEqual(links[link_id].lm_logprob, lm_logprob)

    # (number of in-links, number of out-links) for each node.
    expected_degrees = [(0, 1), (1, 2), (1, 1), (2, 0)]
    for node, (num_in_links, num_out_links) in zip(nodes, expected_degrees):
        self.assertEqual(len(node.in_links), num_in_links)
        self.assertEqual(len(node.out_links), num_out_links)

    # Links appear in in_links/out_links in the order they were read.
    self.assertEqual(nodes[0].out_links[0].end_node.time, 1.0)
    self.assertEqual(nodes[1].in_links[0].start_node.time, 0.0)
    self.assertEqual(nodes[1].out_links[0].end_node.time, 2.0)
    self.assertEqual(nodes[1].out_links[1].end_node.time, 3.0)
    self.assertEqual(nodes[2].in_links[0].start_node.time, 1.0)
    self.assertEqual(nodes[2].out_links[0].end_node.time, 3.0)
    self.assertEqual(nodes[3].in_links[0].start_node.time, 2.0)
    self.assertEqual(nodes[3].in_links[1].start_node.time, 1.0)
def test_move_words_to_links(self):
    """Check that node words end up on the links entering those nodes.

    After ``_move_words_to_links()`` every link carries the word of its end
    node, and no node has a ``word`` attribute any more.
    """
    lattice = SLFLattice(None)
    node_words = ['A', 'B', 'C', 'D', 'E']
    lattice.nodes = [Lattice.Node(node_id)
                     for node_id in range(len(node_words))]
    for node, word in zip(lattice.nodes, node_words):
        node.word = word
    lattice.initial_node = lattice.nodes[0]
    lattice.final_node = lattice.nodes[-1]

    edges = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)]
    for start_id, end_id in edges:
        lattice._add_link(lattice.nodes[start_id], lattice.nodes[end_id])

    lattice._move_words_to_links()

    # The word of each link is the word of its end node, in link order.
    expected_link_words = ['B', 'C', 'D', 'D', 'E']
    for link_index, word in enumerate(expected_link_words):
        self.assertEqual(lattice.links[link_index].word, word)
    for node in lattice.nodes:
        self.assertFalse(hasattr(node, 'word'))
def test_sorted_nodes(self):
    """Check that sorted_nodes() orders nodes topologically, then by time.

    First, a hand-built lattice is checked against a known expected order.
    Then a lattice read from ``self.lattice_path`` is checked for two things:
    every link out of a sorted node points to a node sorted later, and
    consecutive nodes that both have a time are in non-decreasing time order.
    """
    lattice = Lattice()
    lattice.nodes = [Lattice.Node(node_id) for node_id in range(9)]
    node_times = [(0, 0.0), (2, 1.0), (4, 2.0), (3, 3.0), (5, 4.0),
                  (1, 4.0), (6, 5.0), (7, None), (8, -1.0)]
    for node_id, node_time in node_times:
        lattice.nodes[node_id].time = node_time
    edges = [(0, 2), (0, 4), (2, 3), (4, 3), (2, 5), (3, 5),
             (5, 1), (5, 6), (5, 7), (1, 8), (6, 8), (7, 8)]
    for start_id, end_id in edges:
        lattice._add_link(lattice.nodes[start_id], lattice.nodes[end_id])
    lattice.initial_node = lattice.nodes[0]
    lattice.final_node = lattice.nodes[8]

    sorted_ids = [node.id for node in lattice.sorted_nodes()]
    # Topologically equal nodes will be sorted in ascending time. The nodes
    # that don't have time will go last (node 7 after nodes 1 and 6). Node 8
    # is the only successor of 1, 6 and 7, so it is last despite its time.
    self.assertEqual(sorted_ids, [0, 2, 4, 3, 5, 1, 6, 7, 8])

    with open(self.lattice_path, 'r') as lattice_file:
        lattice = SLFLattice(lattice_file)

    sorted_nodes = lattice.sorted_nodes()
    # Position of every sorted node, keyed by object identity.
    position = {id(node): index for index, node in enumerate(sorted_nodes)}
    # The order is topological exactly when every link points forward. The
    # previous check only looked at direct links between adjacent nodes.
    for node in sorted_nodes:
        for link in node.out_links:
            self.assertIn(id(link.end_node), position)
            self.assertGreater(position[id(link.end_node)],
                               position[id(node)])
    for left_node, right_node in zip(sorted_nodes, sorted_nodes[1:]):
        if left_node.time is not None and right_node.time is not None:
            self.assertLessEqual(left_node.time, right_node.time)