Example #1
    def test_single_fsa(self):
        s = '''
            0 1 1 0.1
            0 2 1 0.2
            1 3 2 0.3
            2 3 3 0.4
            3 4 -1 0.5
            4
        '''
        fsa = k2.Fsa.from_str(s)
        fsa.requires_grad_(True)
        new_fsa = k2.add_epsilon_self_loops(fsa)
        assert torch.allclose(
            new_fsa.arcs.values()[:, :3],
            torch.tensor([
                [0, 0, 0],
                [0, 1, 1],
                [0, 2, 1],
                [1, 1, 0],
                [1, 3, 2],
                [2, 2, 0],
                [2, 3, 3],
                [3, 3, 0],
                [3, 4, -1],
            ]).to(torch.int32))

        assert torch.allclose(
            new_fsa.scores, torch.tensor([0, 0.1, 0.2, 0, 0.3, 0, 0.4, 0,
                                          0.5]))
        scale = torch.arange(new_fsa.scores.numel())
        (new_fsa.scores * scale).sum().backward()
        assert torch.allclose(fsa.scores.grad,
                              torch.tensor([1., 2., 4., 6., 8.]))
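For reference, a minimal standalone sketch of the same call outside the test harness (assuming only that k2 is installed; the toy FSA is made up):

import k2

# A small acceptor in k2's text format: "src dst label score", final state on the last line.
s = '''
    0 1 1 0.1
    1 2 -1 0.5
    2
'''
fsa = k2.Fsa.from_str(s)

# add_epsilon_self_loops inserts an epsilon (label 0) self-loop in front of each
# state's existing arcs; the final state, which has no outgoing arcs, gets none.
new_fsa = k2.add_epsilon_self_loops(fsa)
print(new_fsa)  # expect arcs (0,0,0), (0,1,1), (1,1,0), (1,2,-1)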
Example #2
    def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
        '''Convert transcripts to an FsaVec with the help of the lexicon
        and word symbol table.

        Args:
          texts:
            Each element is a transcript containing words separated by spaces.
            For instance, it may be 'HELLO SNOWFALL', which contains
            two words.

        Returns:
          Return an FsaVec corresponding to the transcripts. Its `labels` are
          phone IDs and `aux_labels` are word IDs.
        '''
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split(' '):
                if word in self.words:
                    word_ids.append(self.words[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        fsa = k2.linear_fsa(word_ids_list, self.device)
        fsa = k2.add_epsilon_self_loops(fsa)
        num_graphs = k2.intersect(self.L_inv,
                                  fsa,
                                  treat_epsilons_specially=False).invert_()
        num_graphs = k2.arc_sort(num_graphs)
        return num_graphs
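For intuition, a hedged sketch of just the k2.linear_fsa + k2.add_epsilon_self_loops steps used above, on made-up word IDs (the intersection with the project-specific L_inv is omitted):

import k2

word_ids_list = [[3, 5, 2], [7]]  # two toy "transcripts" as word IDs

fsa_vec = k2.linear_fsa(word_ids_list)  # an FsaVec with 2 linear FSAs
# The epsilon self-loops let epsilon arcs in L_inv be matched later when
# intersecting with treat_epsilons_specially=False.
fsa_vec = k2.add_epsilon_self_loops(fsa_vec)
print(fsa_vec.shape)  # (2, None, None)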
Example #3
 def compile_one_and_cache(self, text: str) -> Fsa:
     tokens = (token if token in self.vocab._sym2id else self.oov
               for token in text.split(' '))
     word_ids = [self.vocab.get(token) for token in tokens]
     fsa = k2.linear_fsa(word_ids)
     decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
     decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
     return decoding_graph
Example #4
def intersect_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2.Fsa':
    """Intersection helper function.
    """
    assert hasattr(base_graph, "aux_labels")
    assert not hasattr(aux_graph, "aux_labels")
    aux_graph_with_self_loops = k2.arc_sort(
        k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device)
    result = k2.intersect(k2.arc_sort(base_graph),
                          aux_graph_with_self_loops,
                          treat_epsilons_specially=False)
    setattr(result, "phones", result.labels)
    return result
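A hedged usage sketch for the helper above, with toy graphs (the label values are invented; real callers pass lexicon-style FSTs):

import torch
import k2

# base_graph: labels are phone/token IDs, aux_labels are word IDs (toy values).
base_graph = k2.Fsa.from_str('''
    0 1 1 0.0
    1 2 -1 0.0
    2
''')
base_graph.aux_labels = torch.tensor([10, -1], dtype=torch.int32)

# aux_graph: an acceptor over the same label set as base_graph.labels.
aux_graph = k2.Fsa.from_str('''
    0 1 1 0.0
    1 2 -1 0.0
    2
''')

result = intersect_with_self_loops(base_graph, aux_graph)
print(result.phones)  # the labels of the intersection, exposed as "phones"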
Example #5
    def compile(self,
                texts: Iterable[str],
                P: k2.Fsa,
                replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
          replicate_den:
            If True, the returned den_graph is replicated to match the number
            of FSAs in the returned num_graph; if False, the returned den_graph
            contains only a single FSA.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph` if `replicate_den` is True; otherwise, it
              is an FsaVec containing only a single FSA.
        '''
        assert P.device == self.device
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()

        ctc_topo_P = k2.arc_sort(ctc_topo_P)

        num_graphs = self.build_num_graphs(texts)
        num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            num_graphs)

        num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

        num = k2.compose(ctc_topo_P,
                         num_graphs_with_self_loops,
                         treat_epsilons_specially=False)
        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
        if replicate_den:
            indexes = torch.zeros(len(texts),
                                  dtype=torch.int32,
                                  device=self.device)
            den = k2.index_fsa(ctc_topo_P_vec, indexes)
        else:
            den = ctc_topo_P_vec

        return num, den
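The replicate_den branch relies on k2.index_fsa to duplicate the single denominator FSA once per utterance; a small hedged sketch of that call in isolation (toy FSA):

import torch
import k2

fsa = k2.Fsa.from_str('''
    0 1 1 0.0
    1 2 -1 0.0
    2
''')
fsa_vec = k2.create_fsa_vec([fsa])  # an FsaVec holding a single FSA

indexes = torch.zeros(3, dtype=torch.int32)  # repeat FSA 0 three times
replicated = k2.index_fsa(fsa_vec, indexes)
print(replicated.shape)  # (3, None, None)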
Example #6
def create_decoding_graph(texts, L, symbols):
    word_ids_list = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        word_ids_list.append(word_ids)
    fsa = k2.linear_fsa(word_ids_list)
    decoding_graph = k2.intersect(fsa, L).invert_()
    decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
    return decoding_graph
Example #7
    def test_two_fsas(self):
        s1 = '''
            0 1 1 0.1
            0 2 1 0.2
            1 3 2 0.3
            2 3 3 0.4
            3 4 -1 0.5
            4
        '''
        s2 = '''
            0 1 1 0.1
            0 2 2 0.2
            1 2 3 0.3
            2 3 -1 0.4
            3
        '''

        for device in self.devices:
            fsa1 = k2.Fsa.from_str(s1).to(device)
            fsa2 = k2.Fsa.from_str(s2).to(device)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
            new_fsa_vec = k2.add_epsilon_self_loops(fsa_vec)
            assert torch.all(
                torch.eq(
                    new_fsa_vec.arcs.values()[:, :3],
                    torch.tensor([[0, 0, 0], [0, 1, 1], [0, 2, 1], [1, 1, 0],
                                  [1, 3, 2], [2, 2, 0], [2, 3, 3], [3, 3, 0],
                                  [3, 4, -1], [0, 0, 0], [0, 1, 1], [0, 2, 2],
                                  [1, 1, 0], [1, 2, 3], [2, 2, 0], [2, 3, -1]],
                                 dtype=torch.int32,
                                 device=device)))

            assert torch.allclose(
                new_fsa_vec.scores,
                torch.tensor([
                    0, 0.1, 0.2, 0, 0.3, 0, 0.4, 0, 0.5, 0, 0.1, 0.2, 0, 0.3,
                    0, 0.4
                ]).to(device))

            scale = torch.arange(new_fsa_vec.scores.numel(), device=device)
            (new_fsa_vec.scores * scale).sum().backward()

            assert torch.allclose(
                fsa1.scores.grad,
                torch.tensor([1., 2., 4., 6., 8.], device=device))

            assert torch.allclose(
                fsa2.scores.grad,
                torch.tensor([10., 11., 13., 15.], device=device))
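Aside: k2.create_fsa_vec is what lets the gradients above flow back to fsa1.scores and fsa2.scores; a tiny hedged sketch of stacking FSAs into an FsaVec:

import k2

a = k2.Fsa.from_str('0 1 1 0.5\n1 2 -1 0.0\n2')
b = k2.Fsa.from_str('0 1 2 0.25\n1 2 -1 0.0\n2')

fsa_vec = k2.create_fsa_vec([a, b])
print(fsa_vec.shape)  # (2, None, None)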
Example #8
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
        Returns:
          A tuple (num_graph, den_graph, decoding_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.
              It is the result of compose(ctc_topo, P, L, transcript)

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph`.
              It is the result of compose(ctc_topo, P).

            - `decoding_graph` is the result of compose(ctc_topo, L_disambig, G).
              Note that it is a single Fsa, not an FsaVec.
        '''
        assert P.device == self.device
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()
        ctc_topo_P = k2.arc_sort(ctc_topo_P)

        num_graphs = self.build_num_graphs(texts)

        num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            num_graphs)

        num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

        num = k2.compose(ctc_topo_P,
                         num_graphs_with_self_loops,
                         treat_epsilons_specially=False,
                         inner_labels='phones')
        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        den = k2.index_fsa(ctc_topo_P_vec, indexes)

        return num, den, self.decoding_graph
Example #9
def create_decoding_graph(texts, graph, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        fsa = k2.arc_sort(fsa)
        decoding_graph = k2.intersect(fsa, graph).invert_()
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
Example #10
    def __init__(self,
                 lexicon: Lexicon,
                 P: k2.Fsa,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          lexicon:
            A Lexicon object. Its `L_inv` has words as labels and phones as
            aux_labels, and it also provides the phone and word symbol tables.
          P:
            A phone bigram LM if the pronunciations in the lexicon are in phones;
            a word piece bigram if the pronunciations in the lexicon are word pieces.
          device:
            The device on which the graphs are built.
          oov:
            Out of vocabulary word.
        '''
        self.lexicon = lexicon
        L_inv = self.lexicon.L_inv.to(device)
        P = P.to(device)

        # Ensure L_inv is arc-sorted.
        if (L_inv.properties & k2.fsa_properties.ARC_SORTED) == 0:
            L_inv = k2.arc_sort(L_inv)

        assert L_inv.requires_grad is False

        assert oov in self.lexicon.words

        self.L_inv = L_inv
        self.oov_id = self.lexicon.words[oov]
        self.oov = oov
        self.device = device

        phone_symbols = get_phone_symbols(self.lexicon.phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
        assert ctc_topo.requires_grad is False

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
Example #11
def nbest_am_lm_scores(
    lats: k2.Fsa,
    num_paths: int,
    device: str = "cuda",
    batch_size: int = 500,
):
    """Compute am scores with word_seqs

    Compatible with both ctc_decoding or TLG decoding.
    """
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
    if isinstance(lats.aux_labels, torch.Tensor):
        word_seqs = k2.ragged.index(lats.aux_labels.contiguous(), paths)
    else:
        # '_k2.RaggedInt' object has no attribute 'contiguous'
        word_seqs = lats.aux_labels.index(paths)
        word_seqs = word_seqs.remove_axis(word_seqs.num_axes - 2)

    # With ctc_decoding, word_seqs stores token_ids.
    # With TLG decoding, word_seqs stores word_ids.
    word_seqs = word_seqs.remove_values_leq(0)
    unique_word_seqs, num_repeats, new2old = word_seqs.unique(
        need_num_repeats=True, need_new2old_indexes=True
    )

    seq_to_path_shape = unique_word_seqs.shape.get_layer(0)
    path_to_seq_map = seq_to_path_shape.row_ids(1)
    # used to split final computed tot_scores
    seq_to_path_splits = seq_to_path_shape.row_splits(1)

    unique_word_seqs = unique_word_seqs.remove_axis(0)
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores, lm_scores = compute_am_scores_and_lm_scores(
        lats, word_fsas_with_epsilon_loops, path_to_seq_map, device, batch_size
    )

    token_seqs = k2.ragged.index(lats.labels.contiguous(), paths)
    token_seqs = token_seqs.remove_axis(0)

    token_ids, _ = token_seqs.index(new2old, axis=0)
    token_ids = token_ids.tolist()
    # Now remove repeated tokens and 0s and -1s.
    token_ids = [remove_repeated_and_leq(tokens) for tokens in token_ids]
    return am_scores, lm_scores, token_ids, new2old, path_to_seq_map, seq_to_path_splits
Example #12
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon FST ``L`` and a language model FSA ``G``.
    Involves arc sorting, intersection, determinization, removal of disambiguation
    symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
            and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    Returns:
        The decoding graph ``LG``, arc-sorted, with disambiguation symbols removed
        and epsilon self-loops added.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
Example #13
def create_decoding_graph(texts, L, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        print("linear fsa is ", fsa)
        fsa = k2.arc_sort(fsa)
        print("linear fsa, arc-sorted, is ", fsa)
        print("begin")
        print(k2.is_arc_sorted(k2.get_properties(fsa)))
        decoding_graph = k2.intersect(fsa, L).invert_()
        print("linear fsa, composed, is ", fsa)
        print("decoding graph is ", decoding_graph)
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        print("decoding graph with self-loops is ", decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
Example #14
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.

    """

    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which is the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
Example #15
File: nbest.py Project: k2-fsa/k2
    def intersect(self, lats: Fsa) -> 'Nbest':
        '''Intersect this Nbest object with a lattice and get 1-best
        path from the resulting FsaVec.

        Caution:
          We assume FSAs in `self.fsa` don't have epsilon self-loops.
          We also assume `self.fsa.labels` and `lats.labels` are token IDs.

        Args:
          lats:
            An FsaVec. It can be the return value of
            :func:`whole_lattice_rescoring`.
        Returns:
          Return a new Nbest. This new Nbest shares the same shape as `self`,
          while its `fsa` is the 1-best path from intersecting `self.fsa` and
          `lats`.
        '''
        assert self.fsa.device == lats.device, \
                f'{self.fsa.device} vs {lats.device}'
        assert len(lats.shape) == 3, f'{lats.shape}'
        assert lats.arcs.dim0() == self.shape.dim0(), \
                f'{lats.arcs.dim0()} vs {self.shape.dim0()}'

        lats = k2.arc_sort(lats)  # no-op if lats is already arc sorted

        fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa)

        path_to_seq_map = self.shape.row_ids(1)

        ans_lats = k2.intersect_device(a_fsas=lats,
                                       b_fsas=fsas_with_epsilon_loops,
                                       b_to_a_map=path_to_seq_map,
                                       sorted_match_a=True)

        one_best = k2.shortest_path(ans_lats, use_double_scores=True)

        one_best = k2.remove_epsilon(one_best)

        return Nbest(fsa=one_best, shape=self.shape)
Example #16
def main():
    parser = get_parser()
    GigaSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    model_type = args.model_type
    epoch = args.epoch
    avg = args.avg
    att_rate = args.att_rate
    num_paths = args.num_paths
    use_lm_rescoring = args.use_lm_rescoring
    use_whole_lattice = False
    if use_lm_rescoring and num_paths < 1:
        # It doesn't make sense to use n-best list for rescoring
        # when n is less than 1
        use_whole_lattice = True

    output_beam_size = args.output_beam_size

    suffix = ''
    if args.context_window is not None and args.context_window > 0:
        suffix = f'ac{args.context_window}'
    giga_subset = f'giga{args.subset}'
    exp_dir = Path(
        f'exp-{model_type}-mmi-att-sa-vgg-normlayer-{giga_subset}-{suffix}')

    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    logging.info(f'output_beam_size: {output_beam_size}')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend,
            is_espnet_structure=args.is_espnet_structure)
    elif model_type == "contextnet":
        model = ContextNet(num_features=80, num_classes=len(phone_ids) +
                           1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)
        ts_path = exp_dir / f'model_ts_epoch{epoch}_avg{avg}.pt'
        logging.info(f'Storing the TorchScripted model in {ts_path}')
        model.save(ts_path)

    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    if use_lm_rescoring:
        if use_whole_lattice:
            logging.info('Rescoring with the whole lattice')
        else:
            logging.info(f'Rescoring with n-best list, n is {num_paths}')
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        if not os.path.exists(lang_dir / 'G_4_gram.pt'):
            logging.debug('Loading G_4_gram.fst.txt')
            with open(lang_dir / 'G_4_gram.fst.txt') as f:
                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION(fangjun): The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                G = k2.create_fsa_vec([G]).to(device)
                G = k2.arc_sort(G)
                torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt')
        else:
            logging.debug('Loading pre-compiled G_4_gram.pt')
            d = torch.load(lang_dir / 'G_4_gram.pt')
            G = k2.Fsa.from_dict(d).to(device)

        if use_whole_lattice:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)
        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        logging.debug('Decoding without LM rescoring')
        G = None
        if num_paths > 1:
            logging.debug(f'Use n-best list decoding, n is {num_paths}')
        else:
            logging.debug('Use 1-best decoding')

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    if not hasattr(HLG, 'lm_scores'):
        HLG.lm_scores = HLG.scores.clone()

    # load dataset
    gigaspeech = GigaSpeechAsrDataModule(args)
    test_sets = ['DEV', 'TEST']
    for test_set, test_dl in zip(
            test_sets,
        [gigaspeech.valid_dataloaders(),
         gigaspeech.test_dataloaders()]):
        logging.info(f'* DECODING: {test_set}')

        test_set_wers = dict()
        results_dict = decode(dataloader=test_dl,
                              model=model,
                              HLG=HLG,
                              symbols=symbol_table,
                              num_paths=num_paths,
                              G=G,
                              use_whole_lattice=use_whole_lattice,
                              output_beam_size=output_beam_size)

        for key, results in results_dict.items():
            recog_path = exp_dir / f'recogs-{test_set}-{key}.txt'
            store_transcripts(path=recog_path, texts=results)
            logging.info(f'The transcripts are stored in {recog_path}')

            ref_path = exp_dir / f'ref-{test_set}.trn'
            hyp_path = exp_dir / f'hyp-{test_set}.trn'
            store_transcripts_for_sclite(ref_path=ref_path,
                                         hyp_path=hyp_path,
                                         texts=results)
            logging.info(
                f'The sclite-format transcripts are stored in {ref_path} and {hyp_path}'
            )
            cmd = f'python3 GigaSpeech/utils/gigaspeech_scoring.py {ref_path} {hyp_path} {exp_dir / "tmp_sclite"}'
            logging.info(cmd)
            try:
                subprocess.run(cmd, check=True, shell=True)
            except subprocess.CalledProcessError:
                logging.error(
                    'Skipping sclite scoring as it failed to run: Is "sclite" registered in your $PATH?'
                )

            # The following prints out WERs, per-word error statistics and aligned
            # ref/hyp pairs.
            errs_filename = exp_dir / f'errs-{test_set}-{key}.txt'
            with open(errs_filename, 'w') as f:
                wer = write_error_stats(f, f'{test_set}-{key}', results)
                test_set_wers[key] = wer

            logging.info(
                'Wrote detailed error stats to {}'.format(errs_filename))

        test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
        errs_info = exp_dir / f'wer-summary-{test_set}.txt'
        with open(errs_info, 'w') as f:
            print('settings\tWER', file=f)
            for key, val in test_set_wers:
                print('{}\t{}'.format(key, val), file=f)

        s = '\nFor {}, WER of different settings are:\n'.format(test_set)
        note = '\tbest for {}'.format(test_set)
        for key, val in test_set_wers:
            s += '{}\t{}{}\n'.format(key, val, note)
            note = ''
        logging.info(s)
Example #17
import k2
s1 = '''
0 1 0 0.1
0 1 1 0.2
1 1 2 0.3
1 2 -1 0.4
2
'''

s2 = '''
0 1 1 1
0 1 2 2
1 2 -1 3
2
'''

a_fsa = k2.Fsa.from_str(s1)
b_fsa = k2.Fsa.from_str(s2)
b_fsa = k2.add_epsilon_self_loops(b_fsa)
c_fsa = k2.intersect(a_fsa, b_fsa, treat_epsilons_specially=False)

a_fsa.draw('a_fsa_intersect3.svg', title='a_fsa')
b_fsa.draw('b_fsa_intersect3.svg', title='b_fsa')
c_fsa.draw('c_fsa_intersect3.svg', title='c_fsa')
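Note on the design choice here: with treat_epsilons_specially=False, k2.intersect treats label 0 like any other symbol, so the epsilon arc in a_fsa (the `0 1 0 0.1` arc) can only be traversed if b_fsa has a matching label-0 arc; adding epsilon self-loops to b_fsa provides exactly those arcs.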
Example #18
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    model_type = args.model_type
    epoch = args.epoch
    avg = args.avg
    att_rate = args.att_rate
    num_paths = args.num_paths
    use_lm_rescoring = args.use_lm_rescoring
    use_whole_lattice = False
    if use_lm_rescoring and num_paths < 1:
        # It doesn't make sense to use n-best list for rescoring
        # when n is less than 1
        use_whole_lattice = True

    output_beam_size = args.output_beam_size

    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    logging.info(f'output_beam_size: {output_beam_size}')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "contextnet":
        model = ContextNet(num_features=80, num_classes=len(phone_ids) +
                           1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False)

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    assert P.requires_grad is False
    P.scores = model.P_scores.cpu()
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='model_P_scores.txt')

    P.set_scores_stochastic_(model.P_scores)
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='P_scores.txt')

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    if use_lm_rescoring:
        if use_whole_lattice:
            logging.info('Rescoring with the whole lattice')
        else:
            logging.info(f'Rescoring with n-best list, n is {num_paths}')
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        if not os.path.exists(lang_dir / 'G_4_gram.pt'):
            logging.debug('Loading G_4_gram.fst.txt')
            with open(lang_dir / 'G_4_gram.fst.txt') as f:
                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION(fangjun): The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                G = k2.create_fsa_vec([G]).to(device)
                G = k2.arc_sort(G)
                torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt')
        else:
            logging.debug('Loading pre-compiled G_4_gram.pt')
            d = torch.load(lang_dir / 'G_4_gram.pt')
            G = k2.Fsa.from_dict(d).to(device)

        if use_whole_lattice:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)
    else:
        logging.debug('Decoding without LM rescoring')
        G = None

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    if not hasattr(HLG, 'lm_scores'):
        HLG.lm_scores = HLG.scores.clone()

    # load dataset
    librispeech = LibriSpeechAsrDataModule(args)
    test_sets = ['test-clean', 'test-other']
    #  test_sets = ['test-other']
    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
        logging.info(f'* DECODING: {test_set}')

        results = decode(dataloader=test_dl,
                         model=model,
                         device=device,
                         HLG=HLG,
                         symbols=symbol_table,
                         num_paths=num_paths,
                         G=G,
                         use_whole_lattice=use_whole_lattice,
                         output_beam_size=output_beam_size)

        recog_path = exp_dir / f'recogs-{test_set}.txt'
        store_transcripts(path=recog_path, texts=results)
        logging.info(f'The transcripts are stored in {recog_path}')

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = exp_dir / f'errs-{test_set}.txt'
        with open(errs_filename, 'w') as f:
            write_error_stats(f, test_set, results)
        logging.info('Wrote detailed error stats to {}'.format(errs_filename))
Example #19
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
    Returns:
      An FsaVec containing the best path of each sequence after rescoring.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.

    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
Example #20
def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
                             num_paths: int) -> k2.Fsa:
    '''Decode using n-best list with LM rescoring.

    `lats` is a decoding lattice, which has 3 axes. This function first
    extracts `num_paths` paths from `lats` for each sequence using
    `k2.random_paths`. The `am_scores` of these paths are computed.
    For each path, its `lm_scores` is computed using `G` (which is an LM).
    The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
    The path with the greatest `tot_scores` within a sequence is used
    as the decoding output.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      num_paths:
        It is the size `n` in `n-best` list.
    Returns:
      An FsaVec representing the best decoding path for each sequence
      in the lattice.
    '''
    device = lats.device

    assert len(lats.shape) == 3
    assert hasattr(lats, 'aux_labels')
    assert hasattr(lats, 'lm_scores')

    assert G.shape == (1, None, None)
    assert G.device == device
    assert hasattr(G, 'aux_labels') is False

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # unique_word_seqs is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a seq.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=True, need_new2old_indexes=True)

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops,
                                  path_to_seq_map)

    # Now compute lm_scores
    b_to_a_map = torch.zeros_like(path_to_seq_map)
    lm_path_lats = _intersect_device(G,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=b_to_a_map,
                                     sorted_match_a=True)
    lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
    lm_scores = lm_path_lats.get_tot_scores(True, True)

    tot_scores = am_scores + lm_scores

    # Remember that we used `k2.ragged.unique_sequences` to remove repeated
    # paths to avoid redundant computation in `k2.intersect_device`.
    # Now we use `num_repeats` to correct the scores for each path.
    #
    # NOTE(fangjun): It is commented out as it leads to a worse WER
    # tot_scores = tot_scores * num_repeats.values()

    # TODO(fangjun): We may need to add `k2.RaggedDouble`
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))
    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
Example #21
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''
    (Ideas of this function are from Dan)

    It implements something like CTC prefix beam search using n-best lists.

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total scores
    of these sequences in the log-semiring. The one with the max score
    is used as the decoding output.
    '''

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.

    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if inv_lats is already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels

    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    tot_scores = path_lats.get_tot_scores(True, True)
    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `paths`, we use `new2old`
    # here to convert argmax_indexes to the indexes into `paths`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
Example #22
def compose_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2.Fsa':
    """Compose ``base_graph`` with ``aux_graph`` after adding epsilon
    self-loops to ``aux_graph``; the matched inner labels are exposed on the
    result as ``phones``.
    """
    aux_graph_with_self_loops = k2.arc_sort(
        k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device)
    return k2.compose(base_graph, aux_graph_with_self_loops,
                      treat_epsilons_specially=False, inner_labels="phones")
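A hedged usage sketch for the helper above, with toy graphs (in practice base_graph would be something like a CTC topology whose aux_labels are token IDs, and aux_graph a transcript acceptor over those IDs):

import torch
import k2

# base_graph: labels are topology labels, aux_labels are token IDs (toy values).
base_graph = k2.Fsa.from_str('''
    0 1 1 0.0
    1 2 -1 0.0
    2
''')
base_graph.aux_labels = torch.tensor([5, -1], dtype=torch.int32)

# aux_graph: an acceptor over token IDs, matched against base_graph.aux_labels.
aux_graph = k2.Fsa.from_str('''
    0 1 5 0.0
    1 2 -1 0.0
    2
''')

graph = compose_with_self_loops(base_graph, aux_graph)
print(graph.phones)  # the matched inner labels, attached via inner_labels="phones"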