Example #1
    def test_no_repeated(self):
        # standard ctc topo and modified ctc topo
        # should be equivalent if there are no
        # repeated neighboring symbols in the transcript
        max_token = 3
        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        transcript = k2.linear_fsa([1, 2, 3])
        standard_graph = k2.compose(standard, transcript)
        modified_graph = k2.compose(modified, transcript)

        input1 = k2.linear_fsa([1, 1, 1, 0, 0, 2, 2, 3, 3])
        input2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 3, 3])
        inputs = [input1, input2]
        for i in inputs:
            lattice1 = k2.intersect(standard_graph,
                                    i,
                                    treat_epsilons_specially=False)
            lattice2 = k2.intersect(modified_graph,
                                    i,
                                    treat_epsilons_specially=False)
            lattice1 = k2.connect(lattice1)
            lattice2 = k2.connect(lattice2)

            aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
            aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
            aux_labels1 = aux_labels1[:-1]  # remove -1
            aux_labels2 = aux_labels2[:-1]
            assert torch.all(torch.eq(aux_labels1, aux_labels2))
            assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 3])))
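As a quick, self-contained sketch (assuming only that k2 is installed), the training graph used in the test above can be built by composing the CTC topology with a linear FSA over the transcript; the token IDs below are toy values, not from the original test.

import k2

# Compose the standard CTC topology with a toy transcript [1, 2, 3]
# and keep only the accessible/co-accessible part.
standard = k2.ctc_topo(max_token=3, modified=False)
transcript = k2.linear_fsa([1, 2, 3])
graph = k2.connect(k2.compose(standard, transcript))
print(graph.num_arcs)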
Example #2
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with the same
              shape as `num_graph`.
        '''
        assert P.is_cpu()

        ctc_topo_P = k2.intersect(self.ctc_topo, P).invert_()
        ctc_topo_P = k2.connect(ctc_topo_P)

        num_graphs = k2.create_fsa_vec(
            [self.compile_one_and_cache(text) for text in texts])

        num = k2.compose(ctc_topo_P, num_graphs)
        num = k2.connect(num)
        num = k2.arc_sort(num)

        den = k2.create_fsa_vec([ctc_topo_P.detach()] * len(texts))

        return num, den
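For reference, here is a minimal sketch of how the numerator FsaVec above is assembled from per-transcript linear FSAs; the transcripts below are made-up token-ID lists and are not part of the original class.

import k2

# Build one linear FSA per toy transcript and stack them into an FsaVec.
fsas = [k2.linear_fsa(t) for t in ([1, 2], [3], [1, 3, 2])]
num_graphs = k2.create_fsa_vec(fsas)
print(num_graphs.shape)  # (3, None, None)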
Example #3
 def compile_one_and_cache(self, text: str) -> k2.Fsa:
     tokens = (token if token in self.words else self.oov
               for token in text.split(' '))
     word_ids = [self.words[token] for token in tokens]
     fsa = k2.linear_fsa(word_ids)
     decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
     decoding_graph = k2.arc_sort(decoding_graph)
     decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
     decoding_graph = k2.connect(decoding_graph)
     return decoding_graph
Example #4
    def test_with_repeated(self):
        max_token = 2
        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        transcript = k2.linear_fsa([1, 2, 2])
        standard_graph = k2.compose(standard, transcript)
        modified_graph = k2.compose(modified, transcript)

        # There is a blank separating 2 in the input
        # so standard and modified ctc topo should be equivalent
        input = k2.linear_fsa([1, 1, 2, 2, 0, 2, 2, 0, 0])
        lattice1 = k2.intersect(standard_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)

        aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels1 = aux_labels1[:-1]  # remove -1
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels1, aux_labels2))
        assert torch.all(torch.eq(aux_labels1, torch.tensor([1, 2, 2])))

        # There are no blanks separating 2 in the input.
        # The standard ctc topo requires that there must be a blank
        # separating 2, so lattice1 in the following is empty
        input = k2.linear_fsa([1, 1, 2, 2, 0, 0])
        lattice1 = k2.intersect(standard_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)
        assert lattice1.num_arcs == 0

        # Since there are two 2s in the input and there are also two 2s
        # in the transcript, the final output contains only one path.
        # If there were more than two 2s in the input, the output
        # would contain more than one path
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 2])))
Example #5
    def test_compose(self):
        s = '''
            0 1 11 1 1.0
            0 2 12 2 2.5
            1 3 -1 -1 0
            2 3 -1 -1 2.5
            3
        '''
        a_fsa = k2.Fsa.from_str(s, num_aux_labels=1).requires_grad_(True)

        s = '''
            0 1 1 1 1.0
            0 2 2 3 3.0
            1 2 3 2 2.5
            2 3 -1 -1 2.0
            3
        '''
        b_fsa = k2.Fsa.from_str(s, num_aux_labels=1).requires_grad_(True)

        ans = k2.compose(a_fsa, b_fsa, inner_labels='inner')
        ans = k2.connect(ans)

        ans = k2.create_fsa_vec([ans])

        scores = ans.get_tot_scores(log_semiring=True, use_double_scores=False)
        # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
        # are computed using GTN.
        # See https://bit.ly/3heLAJq
        assert scores.item() == 10
        scores.backward()
        assert torch.allclose(a_fsa.grad, torch.tensor([0., 1., 0., 1.]))
        assert torch.allclose(b_fsa.grad, torch.tensor([0., 1., 0., 1.]))
Example #6
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.

    """

    if lexicon is None:
        # No lexicon available: return the word-level transcripts as tensors
        # so the declared return type (List[Tensor]) still holds.
        return [torch.tensor(y) for y in ys]
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
Example #7
    def test_case1(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # suppose we have four symbols: <blk>, a, b, c, d
            torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                             0.2]).to(device)
            k2_activation = torch_activation.detach().clone()

            # (T, N, C)
            torch_activation = torch_activation.reshape(
                1, 1, -1).requires_grad_(True)

            # (N, T, C)
            k2_activation = k2_activation.reshape(1, 1,
                                                  -1).requires_grad_(True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)

            # we have only one sequence and its label is `a`
            targets = torch.tensor([1]).to(device)
            input_lengths = torch.tensor([1]).to(device)
            target_lengths = torch.tensor([1]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert torch.allclose(torch_loss,
                                  torch.tensor([1.6094379425049]).to(device))

            # (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)

            supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([1])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
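Below is a minimal sketch of the DenseFsaVec construction used in the test above, assuming one sequence of 3 frames over 5 classes; each supervision row is [fsa_index, start_frame, num_frames].

import torch
import k2

# Random log-probabilities with layout (N, T, C).
log_probs = torch.randn(1, 3, 5).log_softmax(dim=-1)
supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)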
Example #8
    def test_ragged_aux_labels(self):
        s1 = '''
            0 1 1 0.1
            0 2 5 0.6
            1 2 3 0.3
            2 3 3 0.5
            2 4 2 0.6
            3 5 -1 0.7
            4 5 -1 0.8
            5
        '''

        s2 = '''
            0 0 2 1 1
            0 1 4 3 2
            0 1 6 2 2
            0 2 -1 -1 0
            1 1 2 5 3
            1 2 -1 -1 4
            2
        '''
        # https://git.io/JqNok
        fsa1 = k2.Fsa.from_str(s1)
        fsa1.aux_labels = k2.RaggedInt('[[2] [2 4] [5] [3] [2] [-1] [-1]]')

        # https://git.io/JqNaJ
        fsa2 = k2.Fsa.from_str(s2, num_aux_labels=1)

        # https://git.io/JqNon
        ans = k2.connect(k2.compose(fsa1, fsa2, inner_labels='phones'))

        assert torch.all(torch.eq(ans.labels, torch.tensor([5, 0, 2, -1])))
        assert torch.all(torch.eq(ans.phones, torch.tensor([2, 4, 2, -1])))
        assert str(ans.aux_labels) == str(k2.RaggedInt('[[1] [3] [5] [-1]]'))
Example #9
    def test_compose_inner_labels(self):
        s1 = '''
            0 1 1 2 0.1
            0 2 0 2 0.2
            1 3 3 5 0.3
            2 3 5 4 0.4
            3 4 3 3 0.5
            3 5 2 2 0.6
            4 6 -1 -1 0.7
            5 6 -1 -1 0.8
            6
        '''

        s2 = '''
            0 0 2 1 1
            0 1 4 3 2
            0 1 6 2 2
            0 2 -1 -1 0
            1 1 2 5 3
            1 2 -1 -1 4
            2
        '''

        # https://git.io/JqN2j
        fsa1 = k2.Fsa.from_str(s1, num_aux_labels=1)

        # https://git.io/JqNaJ
        fsa2 = k2.Fsa.from_str(s2, num_aux_labels=1)

        # https://git.io/JqNaT
        ans = k2.connect(k2.compose(fsa1, fsa2, inner_labels='phones'))

        assert torch.all(torch.eq(ans.labels, torch.tensor([0, 5, 2, -1])))
        assert torch.all(torch.eq(ans.phones, torch.tensor([2, 4, 2, -1])))
        assert torch.all(torch.eq(ans.aux_labels, torch.tensor([1, 3, 5, -1])))
Example #10
 def compile_one_and_cache(self, text: str) -> Fsa:
     tokens = (token if token in self.vocab._sym2id else self.oov
               for token in text.split(' '))
     word_ids = [self.vocab.get(token) for token in tokens]
     fsa = k2.linear_fsa(word_ids)
     decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
     decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
     return decoding_graph
Example #11
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa ``G``.
    Involves arc sorting, intersection, determinization, removal of disambiguation symbols
    and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. it has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    :return:
        An arc-sorted ``Fsa`` representing the decoding graph ``LG``.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
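The disambiguation-symbol removal above works by mapping every label at or beyond the first disambiguation ID to epsilon (0). Here is a toy, self-contained sketch of just that step, assuming IDs >= 3 are disambiguation symbols in this made-up transducer:

import k2

s = '''
0 1 1 2 0.0
1 2 3 4 0.0
2 3 -1 -1 0.0
3
'''
fsa = k2.Fsa.from_str(s, num_aux_labels=1)
# Map disambiguation symbols on both sides to epsilon (0).
fsa.labels[fsa.labels >= 3] = 0
fsa.aux_labels[fsa.aux_labels >= 3] = 0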
Example #12
    def test_case3(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # (T, N, C)
            torch_activation = torch.tensor([[
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]]).permute(1, 0, 2).to(device).requires_grad_(True)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `b,c`
            targets = torch.tensor([2, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([2, 3])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss,
                                  torch.tensor([4.938850402832]).to(device))

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #13
    def test_random_case1(self):
        # 1 sequence
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            T = torch.randint(10, 100, (1,)).item()
            C = torch.randint(20, 30, (1,)).item()
            torch_activation = torch.rand((1, T + 10, C),
                                          dtype=torch.float32,
                                          device=device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            # [N, T, C] -> [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation.permute(1, 0, 2), dim=-1)

            input_lengths = torch.tensor([T]).to(device)
            target_lengths = torch.randint(1, T, (1,)).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.item(),)).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo(list(range(C))).invert_())
            linear_fsa = k2.linear_fsa([targets.tolist()])

            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation.grad,
                                  k2_activation.grad,
                                  atol=1e-2)
Example #14
def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
                      path_to_seq_map: torch.Tensor) -> torch.Tensor:
    '''Compute AM scores of n-best lists (represented as word_fsas).

    Args:
      lats:
        An FsaVec, which is the output of `k2.intersect_dense_pruned`.
        It must have the attribute `lm_scores`.
      word_fsas_with_epsilon_loops:
        An FsaVec representing an n-best list. Note that it has been processed
        by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
        which sequence the i-th Fsa in word_fsas_with_epsilon_loops belongs to.
        path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
    Returns:
      Return a 1-D torch.Tensor containing the AM scores of each path.
      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
    '''
    device = lats.device
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lats`, we use k2.intersect_device here.
    #
    # lats has phone IDs as `labels` and word IDs as aux_labels, so we
    # need to invert it here.
    inverted_lats = k2.invert(lats)

    # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are phone IDs (a k2.RaggedInt with 2 axes)

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    del inverted_lats.aux_labels
    inverted_lats = k2.arc_sort(inverted_lats)

    am_path_lats = _intersect_device(inverted_lats,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=path_to_seq_map,
                                     sorted_match_a=True)

    # NOTE: `k2.connect` and `k2.top_sort` support only CPU at present
    am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu'))).to(device)

    # The `scores` of every arc consists of `am_scores` and `lm_scores`
    am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores

    am_scores = am_path_lats.get_tot_scores(True, True)

    return am_scores
Example #15
def _generate_fsa_vec(min_num_fsas: int = 20,
                      max_num_fsas: int = 21,
                      acyclic: bool = True,
                      max_symbol: int = 20,
                      min_num_arcs: int = 10,
                      max_num_arcs: int = 15) -> k2.Fsa:
    fsa = k2.random_fsa_vec(min_num_fsas, max_num_fsas, acyclic, min_num_arcs,
                            max_num_arcs)
    fsa = k2.connect(fsa)
    while True:
        success = True
        for i in range(fsa.shape[0]):
            if fsa[i].shape[0] == 0:
                success = False
                break
        if success:
            break
        else:
            fsa = k2.random_fsa_vec(min_num_fsas, max_num_fsas, acyclic,
                                    min_num_arcs, max_num_arcs)
            fsa = k2.connect(fsa)
    return fsa
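As a usage note (assuming the helper above is in scope), every FSA in the returned vector is guaranteed to be non-empty, which the following check confirms:

fsa_vec = _generate_fsa_vec()
for i in range(fsa_vec.shape[0]):
    assert fsa_vec[i].shape[0] > 0  # at least one state in every FSA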
Example #16
    def test_case2(self):
        for device in self.devices:
            # (T, N, C)
            torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
                1, 0, 2).to(device)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `c,c`
            targets = torch.tensor([3, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([3, 3])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss,
                                  torch.tensor([7.355742931366]).to(device))

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #17
    def compile_one_and_cache(self, text: str) -> k2.Fsa:
        '''Convert transcript to an Fsa with the help of lexicon
        and word symbol table.

        Args:
          text:
            The transcript containing words separated by spaces.

        Returns:
          Return an FST corresponding to the transcript. Its `labels` are
          phone IDs and `aux_labels` are word IDs.
        '''
        tokens = (token if token in self.words else self.oov
                  for token in text.split(' '))
        word_ids = [self.words[token] for token in tokens]
        fsa = k2.linear_fsa(word_ids)
        num_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
        num_graph = k2.arc_sort(num_graph)
        return num_graph
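Below is a toy, self-contained version of the same transcript-to-FSA transform, using a hypothetical one-word lexicon in which word 1 expands to phones 2 and 3 (all IDs here are made up for illustration):

import k2

# L maps phone sequences to words: phones are labels, words are aux_labels.
L = k2.Fsa.from_str('''
0 1 2 1 0.0
1 2 3 0 0.0
2 3 -1 -1 0.0
3
''', num_aux_labels=1)
L_inv = k2.arc_sort(L.invert_())

fsa = k2.linear_fsa([1])  # a transcript containing the single word 1
num_graph = k2.connect(k2.intersect(fsa, L_inv)).invert_()
num_graph = k2.arc_sort(num_graph)
# num_graph now has phone IDs as labels and word IDs as aux_labels.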
Example #18
    def test(self):
        s = '''
            0 1 1 0.1
            0 2 2 0.2
            1 4 -1 0.3
            3 4 -1 0.4
            4
        '''
        fsa = k2.Fsa.from_str(s)
        fsa.requires_grad_(True)
        expected_str = '\n'.join(['0 1 1 0.1', '1 2 -1 0.3', '2'])
        connected_fsa = k2.connect(fsa)
        actual_str = k2.to_str_simple(connected_fsa)
        assert actual_str.strip() == expected_str

        loss = connected_fsa.scores.sum()
        loss.backward()
        assert torch.allclose(fsa.scores.grad,
                              torch.tensor([1, 0, 1, 0], dtype=torch.float32))
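For reference, the acceptor text format accepted by k2.Fsa.from_str throughout these examples is one arc per line, 'src_state dest_state label score', with the final state alone on the last line; a minimal round-trip sketch:

import k2

s = '''
0 1 1 0.5
1 2 -1 0.0
2
'''
fsa = k2.Fsa.from_str(s)
print(k2.to_str_simple(fsa))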
Example #19
    def test_composition_equivalence(self):
        index = _generate_fsa_vec()
        index = k2.arc_sort(k2.connect(k2.remove_epsilon(index)))

        src = _generate_fsa_vec()

        replace = k2.replace_fsa(src, index, 1)
        replace = k2.top_sort(replace)

        f_fsa = _construct_f(src)
        f_fsa = k2.arc_sort(f_fsa)
        intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
        intersect = k2.invert(intersect)
        intersect = k2.top_sort(intersect)
        delattr(intersect, 'aux_labels')

        assert k2.is_rand_equivalent(replace,
                                     intersect,
                                     log_semiring=True,
                                     delta=1e-3)
Example #20
 def test_random(self):
     while True:
         fsa = k2.random_fsa(max_symbol=20,
                             min_num_arcs=50,
                             max_num_arcs=500)
         fsa = k2.arc_sort(k2.connect(k2.remove_epsilon(fsa)))
         prob = fsa.properties
         # we need non-deterministic fsa
         if not prob & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC:
             break
     log_semiring = False
     # test weight pushing tropical
     dest_max = k2.determinize(
         fsa, k2.DeterminizeWeightPushingType.kTropicalWeightPushing)
     self.assertTrue(
         k2.is_rand_equivalent(fsa, dest_max, log_semiring, delta=1e-3))
     # test weight pushing log
     dest_log = k2.determinize(
         fsa, k2.DeterminizeWeightPushingType.kLogWeightPushing)
     self.assertTrue(
         k2.is_rand_equivalent(fsa, dest_log, log_semiring, delta=1e-3))
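Here is a minimal, self-contained sketch of the same determinization call on a tiny hand-written non-deterministic acceptor (assuming only that k2 is installed):

import k2

s = '''
0 1 1 0.1
0 2 1 0.2
1 3 -1 0.0
2 3 -1 0.0
3
'''
fsa = k2.arc_sort(k2.Fsa.from_str(s))
dest = k2.determinize(
    fsa, k2.DeterminizeWeightPushingType.kTropicalWeightPushing)
print(k2.to_str_simple(dest))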
Example #21
    def test_compose(self):
        s = '''
            0 1 11 1 1.0
            0 2 12 2 2.5
            1 3 -1 -1 0
            2 3 -1 -1 2.5
            3
        '''
        a_fsa = k2.Fsa.from_str(s).requires_grad_(True)

        s = '''
            0 1 1 1 1.0
            0 2 2 3 3.0
            1 2 3 2 2.5
            2 3 -1 -1 2.0
            3
        '''
        b_fsa = k2.Fsa.from_str(s).requires_grad_(True)

        ans = k2.compose(a_fsa, b_fsa, inner_labels='inner')
        ans = k2.connect(ans)

        # Convert a single FSA to a FsaVec.
        # It will retain `requires_grad_` of `ans`.
        ans.__dict__['arcs'] = _k2.create_fsa_vec([ans.arcs])

        scores = k2.get_tot_scores(ans,
                                   log_semiring=True,
                                   use_double_scores=False)
        # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
        # are computed using GTN.
        # See https://bit.ly/3heLAJq
        assert scores.item() == 10
        scores.backward()
        assert torch.allclose(a_fsa.grad, torch.tensor([0., 1., 0., 1.]))
        assert torch.allclose(b_fsa.grad, torch.tensor([0., 1., 0., 1.]))
        print(ans)
Example #22
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph by composing the topology fst ``H`` with a lexicon
    fst ``L`` and a language model fsa ``G``. Involves arc sorting, intersection,
    determinization and removal of disambiguation symbols.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. it has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        H:  An ``Fsa`` that represents a specific topology used to convert the network
            outputs to a sequence of phones.
            Typically, it's a CTC topology fst, in which when 0 appears on the left
            side, it represents the blank symbol; when it appears on the right side,
            it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    :return:
        An arc-sorted ``Fsa`` representing the decoding graph ``HLG``, with an
        extra ``lm_scores`` attribute holding a copy of the arc scores.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo LG")
    HLG = k2.compose(H, LG, inner_labels='phones')

    logging.info("Connecting LG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting LG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'LG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )

    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consists of two parts:
    #  scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    HLG.lm_scores = HLG.scores.clone()
    return HLG
Example #23
def compile_LG(L: Fsa, G: Fsa, ctc_topo: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph by composing ``ctc_topo`` with a lexicon fst ``L``
    and a language model fsa ``G``. Involves arc sorting, intersection,
    determinization and removal of disambiguation symbols.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. it has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        ctc_topo:  CTC topology fst, in which when 0 appears on the left side, it represents
                   the blank symbol; when it appears on the right side,
                   it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    :return:
        An arc-sorted ``Fsa`` representing the composed decoding graph.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo LG")
    LG = k2.compose(ctc_topo, LG, inner_labels='phones')

    logging.info("Connecting LG")
    LG = k2.connect(LG)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)
    logging.info(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
Example #24
    def test(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            for use_identity_map, sorted_match_a in [(True, True),
                                                     (False, True),
                                                     (True, False),
                                                     (False, False)]:
                # recognizes (0|1)(0|2)
                s1 = '''
                    0 1 0 0.1
                    0 1 1 0.2
                    1 2 0 0.4
                    1 2 2 0.3
                    2 3 -1 0.5
                    3
                '''

                # recognizes 02*
                s2 = '''
                    0 1 0 1
                    1 1 2 2
                    1 2 -1 3
                    2
                '''

                # recognizes 1*0
                s3 = '''
                    0 0 1 10
                    0 1 0 20
                    1 2 -1 30
                    2
                '''
                a_fsa = k2.Fsa.from_str(s1).to(device)
                b_fsa_1 = k2.Fsa.from_str(s2).to(device)
                b_fsa_2 = k2.Fsa.from_str(s3).to(device)

                a_fsa.requires_grad_(True)
                b_fsa_1.requires_grad_(True)
                b_fsa_2.requires_grad_(True)

                b_fsas = k2.create_fsa_vec([b_fsa_1, b_fsa_2])
                if use_identity_map:
                    a_fsas = k2.create_fsa_vec([a_fsa, a_fsa])
                    b_to_a_map = torch.tensor([0, 1],
                                              dtype=torch.int32).to(device)
                else:
                    a_fsas = k2.create_fsa_vec([a_fsa])
                    b_to_a_map = torch.tensor([0, 0],
                                              dtype=torch.int32).to(device)

                c_fsas = k2.intersect_device(a_fsas, b_fsas, b_to_a_map,
                                             sorted_match_a)
                assert c_fsas.shape == (2, None, None)
                c_fsas = k2.connect(c_fsas.to('cpu'))
                # c_fsas[0] recognizes: 02
                # c_fsas[1] recognizes: 10

                actual_str_0 = k2.to_str(c_fsas[0])
                expected_str_0 = '\n'.join(
                    ['0 1 0 1.1', '1 2 2 2.3', '2 3 -1 3.5', '3'])
                assert actual_str_0.strip() == expected_str_0

                actual_str_1 = k2.to_str(c_fsas[1])
                expected_str_1 = '\n'.join(
                    ['0 1 1 10.2', '1 2 0 20.4', '2 3 -1 30.5', '3'])
                assert actual_str_1.strip() == expected_str_1

                loss = c_fsas.scores.sum()
                (-loss).backward()
                assert torch.allclose(
                    a_fsa.grad,
                    torch.tensor([-1, -1, -1, -1, -2]).to(a_fsa.grad))
                assert torch.allclose(
                    b_fsa_1.grad,
                    torch.tensor([-1, -1, -1]).to(b_fsa_1.grad))
                assert torch.allclose(
                    b_fsa_2.grad,
                    torch.tensor([-1, -1, -1]).to(b_fsa_2.grad))
Example #25
    def test_random_case2(self):
        # 2 sequences
        for device in self.devices:
            T1 = torch.randint(10, 200, (1, )).item()
            T2 = torch.randint(9, 100, (1, )).item()
            C = torch.randint(20, 30, (1, )).item()
            if T1 < T2:
                T1, T2 = T2, T1

            torch_activation_1 = torch.rand((T1, C),
                                            dtype=torch.float32,
                                            device=device).requires_grad_(True)
            torch_activation_2 = torch.rand((T2, C),
                                            dtype=torch.float32,
                                            device=device).requires_grad_(True)

            k2_activation_1 = torch_activation_1.detach().clone(
            ).requires_grad_(True)
            k2_activation_2 = torch_activation_2.detach().clone(
            ).requires_grad_(True)

            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                [torch_activation_1, torch_activation_2],
                batch_first=False,
                padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                [k2_activation_1, k2_activation_2],
                batch_first=True,
                padding_value=0)

            target_length1 = torch.randint(1, T1, (1, )).item()
            target_length2 = torch.randint(1, T2, (1, )).item()

            target_lengths = torch.tensor([target_length1,
                                           target_length2]).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.sum(), )).to(device)

            # [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activations, dim=-1)
            input_lengths = torch.tensor([T1, T2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert T1 >= T2
            supervision_segments = torch.tensor([[0, 0, T1], [1, 0, T2]],
                                                dtype=torch.int32)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo(list(range(C))).invert_())
            linear_fsa = k2.linear_fsa([
                targets[:target_length1].tolist(),
                targets[target_length1:].tolist()
            ])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)
            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)
            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation_1.grad,
                                  k2_activation_1.grad,
                                  atol=1e-2)
            assert torch.allclose(torch_activation_2.grad,
                                  k2_activation_2.grad,
                                  atol=1e-2)
Example #26
    def test_case4(self):
        for device in self.devices:
            # put case3, case2 and case1 into a batch
            torch_activation_1 = torch.tensor(
                [[0., 0., 0., 0., 0.]]).to(device).requires_grad_(True)

            torch_activation_2 = torch.arange(1, 16).reshape(3, 5).to(
                torch.float32).to(device).requires_grad_(True)

            torch_activation_3 = torch.tensor([
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]).to(device).requires_grad_(True)

            k2_activation_1 = torch_activation_1.detach().clone(
            ).requires_grad_(True)
            k2_activation_2 = torch_activation_2.detach().clone(
            ).requires_grad_(True)
            k2_activation_3 = torch_activation_3.detach().clone(
            ).requires_grad_(True)

            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                [torch_activation_3, torch_activation_2, torch_activation_1],
                batch_first=False,
                padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                [k2_activation_3, k2_activation_2, k2_activation_1],
                batch_first=True,
                padding_value=0)

            # [[b,c], [c,c], [a]]
            targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
            input_lengths = torch.tensor([3, 3, 1]).to(device)
            target_lengths = torch.tensor([2, 2, 1]).to(device)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activations, dim=-1)  # (T, N, C)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert torch.allclose(
                torch_loss,
                torch.tensor([4.938850402832, 7.355742931366,
                              1.6094379425049]).to(device))

            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
            supervision_segments = torch.tensor(
                [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            # [ [b, c], [c, c], [a]]
            linear_fsa = k2.linear_fsa([[2, 3], [3, 3], [1]])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)

            scale = torch.tensor([1., -2, 3.5]).to(device)
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation_1.grad,
                                  k2_activation_1.grad)
            assert torch.allclose(torch_activation_2.grad,
                                  k2_activation_2.grad)
            assert torch.allclose(torch_activation_3.grad,
                                  k2_activation_3.grad)
Example #27
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''
    (Ideas of this function are from Dan)

    It implements something like CTC prefix beam search using n-best lists

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total scores
    of these sequences in the log semiring. The one with the max score
    is used as the decoding output.
    '''

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.

    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if inv_lats is already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels

    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    tot_scores = path_lats.get_tot_scores(True, True)
    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `paths`, we use `new2old`
    # here to convert argmax_indexes to the indexes into `paths`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
Example #28
def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                              ) -> Dict[str, k2.Fsa]:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict of FsaVec, whose key is an lm_scale value and whose value contains
      the best decoding path for each sequence in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    #  lats.scores = scores / lm_scale
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ', inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats.to('cpu')).to(device))

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = dict()
    #
    # The following implements
    # scores = (scores - lm_scores)/lm_scale + lm_scores
    #        = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    #
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
Example #29
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
    Returns:
      An FsaVec containing the best decoding path for each sequence
      in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.

    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
Example #30
def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
                             num_paths: int) -> k2.Fsa:
    '''Decode using n-best list with LM rescoring.

    `lats` is a decoding lattice, which has 3 axes. This function first
    extracts `num_paths` paths from `lats` for each sequence using
    `k2.random_paths`. The `am_scores` of these paths are computed.
    For each path, its `lm_scores` is computed using `G` (which is an LM).
    The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
    The path with the greatest `tot_scores` within a sequence is used
    as the decoding output.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      num_paths:
        It is the size `n` in the `n-best` list.
    Returns:
      An FsaVec representing the best decoding path for each sequence
      in the lattice.
    '''
    device = lats.device

    assert len(lats.shape) == 3
    assert hasattr(lats, 'aux_labels')
    assert hasattr(lats, 'lm_scores')

    assert G.shape == (1, None, None)
    assert G.device == device
    assert hasattr(G, 'aux_labels') is False

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # unique_word_seqs is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a seq.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=True, need_new2old_indexes=True)

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops,
                                  path_to_seq_map)

    # Now compute lm_scores
    b_to_a_map = torch.zeros_like(path_to_seq_map)
    lm_path_lats = _intersect_device(G,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=b_to_a_map,
                                     sorted_match_a=True)
    lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
    lm_scores = lm_path_lats.get_tot_scores(True, True)

    tot_scores = am_scores + lm_scores

    # Remember that we used `k2.ragged.unique_sequences` to remove repeated
    # paths to avoid redundant computation in `k2.intersect_device`.
    # Now we use `num_repeats` to correct the scores for each path.
    #
    # NOTE(fangjun): It is commented out as it leads to a worse WER
    # tot_scores = tot_scores * num_repeats.values()

    # TODO(fangjun): We may need to add `k2.RaggedDouble`
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))
    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas