def test_single_fst(self):
    '''Build one linear FST from parallel label lists and inspect it.

    Verifies the number of axes and states, the aux-label count, that all
    arc scores are zero, the (src, dst, label) triple of every arc, and the
    aux labels (the input list followed by -1).
    '''
    labels = [2, 5, 8]
    aux_labels = [3, 6, 9]
    fst = k2.linear_fst(labels, aux_labels)

    # Expected layout: one arc per label plus a final arc labelled -1.
    num_arcs = len(labels) + 1

    assert len(fst.shape) == 2
    assert fst.shape[0] == len(labels) + 2, 'There should be 5 states'
    assert fst.aux_labels.shape[0] == len(aux_labels) + 1

    expected_scores = torch.zeros(num_arcs, dtype=torch.float32)
    assert torch.all(torch.eq(fst.scores, expected_scores))

    # Arc i runs from state i to state i + 1; the final arc carries -1.
    expected_arcs = torch.tensor(
        [[src, src + 1, label] for src, label in enumerate(labels + [-1])],
        dtype=torch.int32)
    # Drop the last column (`scores`) so only (src, dst, label) is compared.
    arc_triples = fst.arcs.values()[:, :-1]
    assert torch.all(torch.eq(arc_triples, expected_arcs))

    expected_aux = torch.tensor(aux_labels + [-1], dtype=torch.int32)
    assert torch.all(torch.eq(fst.aux_labels, expected_aux))
def test_fst_vec(self):
    '''Build a vector of linear FSTs from lists of label sequences.

    Verifies the number of axes and FSTs, that every arc score is zero,
    and that the aux labels of all FSTs are laid out back to back, each
    sequence followed by -1.
    '''
    labels = [[1, 3, 5], [2, 6], [8, 7, 9]]
    aux_labels = [[2, 4, 6], [6, 2], [8, 7, 9]]
    num_labels = sum(map(len, labels))

    fst = k2.linear_fst(labels, aux_labels)

    assert len(fst.shape) == 3
    assert fst.shape[0] == 3, 'There should be 3 FSTs'

    # Each FST contributes one arc per label plus one final arc.
    expected_scores = torch.zeros(num_labels + len(labels),
                                  dtype=torch.float32)
    assert torch.all(torch.eq(fst.scores, expected_scores))

    # Concatenation of every aux-label sequence, each terminated by -1:
    # [2, 4, 6, -1, 6, 2, -1, 8, 7, 9, -1]
    expected_aux_labels = [
        label for sequence in aux_labels for label in sequence + [-1]
    ]
    assert torch.all(
        torch.eq(fst.aux_labels,
                 torch.tensor(expected_aux_labels, dtype=torch.int32)))
def visualize_ctc_topo():
    '''Render standard and modified CTC topologies, plus derived graphs, as SVG.

    Demonstration helper only; it is not a test. Draws each topology, the
    transcript "zoo", the topology-transcript compositions, and the lattices
    obtained by intersecting those graphs with two input label sequences.
    Every image is written to the current directory.
    '''
    max_token = 2

    def intersect_no_eps(graph, linear_inputs):
        # Label 0 (<blk>) occurs in the inputs. Per k2's
        # `treat_epsilons_specially` flag, passing False presumably makes
        # it match like an ordinary label instead of acting as epsilon.
        return k2.intersect(graph, linear_inputs,
                            treat_epsilons_specially=False)

    labels_sym = k2.SymbolTable.from_str('''
        <blk> 0
        z 1
        o 2
    ''')
    aux_labels_sym = k2.SymbolTable.from_str('''
        z 1
        o 2
    ''')
    word_sym = k2.SymbolTable.from_str('''
        zoo 1
    ''')

    # Built in order: standard (modified=False) first, then modified.
    standard, modified = (k2.ctc_topo(max_token, modified=flag)
                          for flag in (False, True))
    for topo in (standard, modified):
        topo.labels_sym = labels_sym
        topo.aux_labels_sym = aux_labels_sym
    standard.draw('standard_topo.svg', title='standard CTC topo')
    modified.draw('modified_topo.svg', title='modified CTC topo')

    # Transcript "zoo": labels z o o; the word id 1 is emitted on the first arc.
    transcript = k2.linear_fst([1, 2, 2], [1, 0, 0])
    transcript.labels_sym = labels_sym
    transcript.aux_labels_sym = word_sym
    transcript.draw('transcript.svg', title='transcript')

    standard_graph = k2.compose(standard, transcript)
    modified_graph = k2.compose(modified, transcript)
    standard_graph.draw('standard_graph.svg', title='standard graph')
    modified_graph.draw('modified_graph.svg', title='modified graph')

    # z z <blk> <blk> o o <blk> o <blk>
    inputs = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 2, 0])
    inputs.labels_sym = labels_sym
    inputs.draw('inputs.svg', title='inputs')

    standard_lattice = intersect_no_eps(standard_graph, inputs)
    standard_lattice.draw('standard_lattice.svg', title='standard lattice')

    modified_lattice = k2.connect(intersect_no_eps(modified_graph, inputs))
    modified_lattice.draw('modified_lattice.svg', title='modified lattice')

    # z z <blk> <blk> o o o <blk>
    inputs2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 2, 0])
    inputs2.labels_sym = labels_sym
    inputs2.draw('inputs2.svg', title='inputs2')

    standard_lattice2 = k2.connect(intersect_no_eps(standard_graph, inputs2))
    # Expected to be empty: the standard topo needs a blank between the two
    # o's of "zoo", and inputs2 has none.
    assert standard_lattice2.num_arcs == 0
    standard_lattice2.draw('standard_lattice2.svg', title='standard lattice2')

    modified_lattice2 = k2.connect(intersect_no_eps(modified_graph, inputs2))
    modified_lattice2.draw('modified_lattice2.svg', title='modified lattice2')