Ejemplo n.º 1
0
 def test_single_fst(self):
     '''Check ``k2.linear_fst`` on a single label/aux_label sequence.

     For N labels the result is expected to have N + 2 states and N + 1
     arcs: one arc per label plus a final arc whose label and aux_label
     are both -1.  Every arc score is expected to be zero.
     '''
     labels = [2, 5, 8]
     aux_labels = [3, 6, 9]
     fst = k2.linear_fst(labels, aux_labels)
     assert len(fst.shape) == 2
     assert fst.shape[0] == len(labels) + 2, 'There should be 5 states'
     assert fst.aux_labels.shape[0] == len(aux_labels) + 1

     # torch.equal compares shapes as well as values, so unlike
     # torch.all(torch.eq(...)) it cannot be fooled by broadcasting
     # (e.g. a 1-element tensor compared against an N-element one).
     assert torch.equal(fst.scores,
                        torch.zeros(len(labels) + 1, dtype=torch.float32))

     # Arc i goes from state i to state i + 1 carrying labels[i]; the last
     # arc carries -1.  This yields
     # [[0, 1, 2], [1, 2, 5], [2, 3, 8], [3, 4, -1]] for the labels above.
     expected_arcs = torch.tensor(
         [[i, i + 1, label] for i, label in enumerate(labels + [-1])],
         dtype=torch.int32)
     # skip the last field `scores`
     assert torch.equal(fst.arcs.values()[:, :-1], expected_arcs)
     assert torch.equal(fst.aux_labels,
                        torch.tensor(aux_labels + [-1], dtype=torch.int32))
Ejemplo n.º 2
0
 def test_fst_vec(self):
     '''Check ``k2.linear_fst`` on a list of label/aux_label sequences.

     The result is expected to have a 3-axis shape whose first axis holds
     one FST per input sequence.  All arc scores are expected to be zero.
     The flattened aux_labels are expected to be each sequence's
     aux_labels followed by a terminating -1.
     '''
     labels = [
         [1, 3, 5],
         [2, 6],
         [8, 7, 9],
     ]
     aux_labels = [
         [2, 4, 6],
         [6, 2],
         [8, 7, 9],
     ]
     num_labels = sum(len(s) for s in labels)
     fst = k2.linear_fst(labels, aux_labels)
     assert len(fst.shape) == 3
     assert fst.shape[0] == len(labels), 'There should be 3 FSTs'

     # One arc per label plus one final arc per FST, all with score 0.
     # torch.equal also checks shapes, so broadcasting cannot hide a
     # length mismatch the way torch.all(torch.eq(...)) can.
     assert torch.equal(
         fst.scores,
         torch.zeros(num_labels + len(labels), dtype=torch.float32))

     # Each FST's aux_labels with -1 appended, concatenated:
     # [2, 4, 6, -1, 6, 2, -1, 8, 7, 9, -1] for the inputs above.
     expected_aux_labels = [a for seq in aux_labels for a in seq + [-1]]
     assert torch.equal(fst.aux_labels,
                        torch.tensor(expected_aux_labels, dtype=torch.int32))
Ejemplo n.º 3
0
    def visualize_ctc_topo():
        '''Draw standard and modified CTC topologies for the word "zoo".

        Demonstration only, not a test.  It calls ``draw`` with ``*.svg``
        filenames for the two topologies, the transcript, the composed
        graphs, two input sequences and the resulting lattices.  The single
        assertion checks that the connected standard lattice for the
        second input sequence has no arcs.
        '''
        largest_token_id = 2
        token_table = k2.SymbolTable.from_str('''
            <blk> 0
            z 1
            o 2
        ''')
        aux_token_table = k2.SymbolTable.from_str('''
            z 1
            o 2
        ''')

        word_table = k2.SymbolTable.from_str('''
            zoo 1
        ''')

        # Build the standard topology first, then the modified one, and
        # attach the same symbol tables to both.
        standard_topo, modified_topo = (
            k2.ctc_topo(largest_token_id, modified=flag)
            for flag in (False, True))
        for topo in (standard_topo, modified_topo):
            topo.labels_sym = token_table
            topo.aux_labels_sym = aux_token_table

        standard_topo.draw('standard_topo.svg', title='standard CTC topo')
        modified_topo.draw('modified_topo.svg', title='modified CTC topo')

        # Transcript for "zoo": tokens z o o; word id 1 only on the first arc.
        transcript = k2.linear_fst([1, 2, 2], [1, 0, 0])
        transcript.labels_sym = token_table
        transcript.aux_labels_sym = word_table
        transcript.draw('transcript.svg', title='transcript')

        standard_graph = k2.compose(standard_topo, transcript)
        modified_graph = k2.compose(modified_topo, transcript)
        standard_graph.draw('standard_graph.svg', title='standard graph')
        modified_graph.draw('modified_graph.svg', title='modified graph')

        # z z <blk> <blk> o o <blk> o <blk>
        first_inputs = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 2, 0])
        first_inputs.labels_sym = token_table
        first_inputs.draw('inputs.svg', title='inputs')

        lattice = k2.intersect(standard_graph,
                               first_inputs,
                               treat_epsilons_specially=False)
        lattice.draw('standard_lattice.svg', title='standard lattice')

        lattice = k2.connect(
            k2.intersect(modified_graph,
                         first_inputs,
                         treat_epsilons_specially=False))
        lattice.draw('modified_lattice.svg', title='modified lattice')

        # z z <blk> <blk> o o o <blk>
        second_inputs = k2.linear_fsa([1, 1, 0, 0, 2, 2, 2, 0])
        second_inputs.labels_sym = token_table
        second_inputs.draw('inputs2.svg', title='inputs2')

        lattice = k2.connect(
            k2.intersect(standard_graph,
                         second_inputs,
                         treat_epsilons_specially=False))
        # The standard topo requires a blank between the two o's in "zoo",
        # which second_inputs does not have, so the lattice is empty.
        assert lattice.num_arcs == 0
        lattice.draw('standard_lattice2.svg', title='standard lattice2')

        lattice = k2.connect(
            k2.intersect(modified_graph,
                         second_inputs,
                         treat_epsilons_specially=False))
        lattice.draw('modified_lattice2.svg', title='modified lattice2')