Example #1
0
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts).

    Each word sequence is turned into a linear FSA, intersected with the
    inverted lexicon and reduced to its shortest path; ``get_texts`` then
    extracts one label sequence per utterance.

    Args:
        ys: Word level transcripts, one list of word IDs per utterance.
        lexicon: Its labels are words, while its aux_labels are phones.
            If ``None``, ``ys`` is returned unchanged.

    Returns:
        List[Tensor]: Phone level transcripts, one tensor per utterance.
        NOTE: when ``lexicon`` is ``None`` the input ``ys`` is returned
        as-is (not converted to tensors).

    """

    if lexicon is None:
        return ys
    L_inv = lexicon

    n_batch = len(ys)
    # int64 indices 0..n_batch-1 (torch.tensor(range(0)) would be float32).
    indices = torch.arange(n_batch)
    # Build the transcript FSAs on the lexicon's device so both operands of
    # k2.intersect live on the same device (CPU vs. CUDA mismatch otherwise).
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
Example #2
0
    def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
        '''Build numerator graphs for a batch of transcripts.

        Every word is mapped to its ID through ``self.lexicon.words``;
        words missing from that table are mapped to ``self.oov_id``. The
        word sequences are turned into linear FSAs with epsilon self-loops,
        intersected with ``self.L_inv``, inverted and arc-sorted.

        Args:
          texts:
            Each element is a transcript containing words separated by
            spaces. For instance, it may be 'HELLO SNOWFALL', which
            contains two words.

        Returns:
          An FST (FsaVec) per transcript. Its `labels` are phone IDs and
          its `aux_labels` are word IDs.
        '''
        word_table = self.lexicon.words

        def to_word_id(word):
            # Unknown words fall back to the OOV symbol.
            if word in word_table:
                return word_table[word]
            return self.oov_id

        word_ids_list = [[to_word_id(word) for word in text.split(' ')]
                         for text in texts]

        linear = k2.add_epsilon_self_loops(
            k2.linear_fsa(word_ids_list, self.device))
        assert linear.device == self.device
        # The explicit epsilon self-loops above go together with
        # treat_epsilons_specially=False (see k2.intersect docs).
        graphs = k2.intersect(self.L_inv,
                              linear,
                              treat_epsilons_specially=False).invert_()
        return k2.arc_sort(graphs)
Example #3
0
 def test_fsa_vec(self):
     """k2.linear_fsa on a list of lists builds an FsaVec of linear FSAs.

     Each FSA gets one arc per symbol plus a final arc labelled -1, and
     every arc score is zero.
     """
     symbols = [
         [1, 3, 5],
         [2, 6],
         [8, 7, 9],
     ]
     num_symbols = sum(len(s) for s in symbols)
     fsa = k2.linear_fsa(symbols)
     assert len(fsa.shape) == 3
     assert fsa.shape[0] == 3, 'There should be 3 FSAs'
     expected_arcs = [
         # fsa 0
         [0, 1, 1],
         [1, 2, 3],
         [2, 3, 5],
         [3, 4, -1],
         # fsa 1
         [0, 1, 2],
         [1, 2, 6],
         [2, 3, -1],
         # fsa 2
         [0, 1, 8],
         [1, 2, 7],
         [2, 3, 9],
         [3, 4, -1]
     ]
     # (src_state, dest_state, label) are integers: compare them exactly.
     # The last field `scores` is skipped.
     assert torch.equal(fsa.arcs.values()[:, :-1],
                        torch.tensor(expected_arcs, dtype=torch.int32))
     # One arc per symbol plus one final arc per FSA, all with score 0.
     assert torch.allclose(
         fsa.scores,
         torch.zeros(num_symbols + len(symbols), dtype=torch.float32))
Example #4
0
    def test_case1(self):
        """One frame, one label: k2's graph-based CTC must match torch.

        Uniform activations over 5 classes and the single target `a` give
        an expected loss of 1.6094379425049 (= log 5). The gradients with
        respect to the activations must also agree.
        """
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # Five symbols in total: <blk> plus the four labels a, b, c, d.
            torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                             0.2]).to(device)
            k2_activation = torch_activation.detach().clone()

            # (T, N, C)
            torch_activation = torch_activation.reshape(
                1, 1, -1).requires_grad_(True)

            # (N, T, C)
            k2_activation = k2_activation.reshape(1, 1,
                                                  -1).requires_grad_(True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)

            # we have only one sequence and its label is `a`
            targets = torch.tensor([1]).to(device)
            input_lengths = torch.tensor([1]).to(device)
            target_lengths = torch.tensor([1]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert torch.allclose(torch_loss,
                                  torch.tensor([1.6094379425049]).to(device))

            # (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)

            supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([1])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #5
0
 def compile_one_and_cache(self, text: str) -> Fsa:
     """Compile one space-separated transcript into a decoding graph.

     Words not present in ``self.vocab._sym2id`` are replaced by
     ``self.oov`` before lookup via ``self.vocab.get``. The linear word FSA
     is intersected with ``self.L_inv``, connected, inverted and then given
     epsilon self-loops.
     """
     word_ids = [
         self.vocab.get(word if word in self.vocab._sym2id else self.oov)
         for word in text.split(' ')
     ]
     graph = k2.intersect(k2.linear_fsa(word_ids), self.L_inv)
     graph = k2.connect(graph).invert_()
     return k2.add_epsilon_self_loops(graph)
 def test_single_fsa(self):
     """linear_fsa_with_self_loops on a single linear FSA.

     The input labels contain explicit 0s; the expected output labels keep
     only the non-zero ones, interleaved with 0, ending in -1, on the same
     device as the input.
     """
     for device in self.devices:
         src = k2.linear_fsa([2, 0, 0, 0, 5, 8], device)
         dst = k2.linear_fsa_with_self_loops(src)
         assert dst.device == src.device
         assert dst.labels.tolist() == [0, 2, 0, 5, 0, 8, 0, -1]
Example #7
0
    def test_case3(self):
        """Three frames, targets `b, c`: k2's graph CTC must match torch.

        The k2 score comes from intersecting the inverted CTC topology with
        the target linear FSA and then with the dense log-probs. Its
        negation must equal torch's ctc_loss (4.938850402832), and the
        gradients must agree.
        """
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # Activations laid out as (T, N, C) = (3, 1, 5).
            rows = [
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]
            torch_activation = torch.tensor([rows]).permute(
                1, 0, 2).to(device).requires_grad_(True)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)
            k2_activation = torch_activation.detach().clone()
            k2_activation.requires_grad_(True)

            # Reference loss; the single sequence's labels are `b, c`.
            targets = torch.tensor([2, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch.nn.functional.log_softmax(torch_activation,
                                                          dim=-1),
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            # k2 works on (N, T, C).
            k2_log_probs = torch.nn.functional.log_softmax(
                k2_activation.permute(1, 0, 2), dim=-1)
            segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs, segments).to(device)

            topo_inv = k2.arc_sort(build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            graph = k2.intersect(topo_inv, k2.linear_fsa([2, 3]))
            graph = k2.connect(graph).invert_().to(device)

            lattice = k2.intersect_dense(graph, dense_fsa_vec, 100.0)
            k2_scores = lattice.get_tot_scores(log_semiring=True,
                                               use_double_scores=False)

            expected = torch.tensor([4.938850402832]).to(device)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss, expected)

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #8
0
def generate_nbest_list(lats: Fsa, num_paths: int) -> Nbest:
    '''Generate an n-best list from a lattice.

    Args:
      lats:
        The decoding lattice from the first pass after LM rescoring.
        lats is an FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`. It must have a `phones`
        attribute holding the token IDs. It is not modified.
      num_paths:
        Size of n for n-best list. CAUTION: After removing paths
        that represent the same token sequences, the number of paths
        in different sequences may not be equal.
    Return:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.
    '''
    assert len(lats.shape) == 3

    # CAUTION: We use `phones` instead of `tokens` here because
    # :func:`compile_HLG` uses `phones`
    #
    # Note: compile_HLG is from k2-fsa/snowfall
    assert hasattr(lats, 'phones')

    # Use `lats.phones` directly. Do not attach a `tokens` attribute to the
    # caller's lattice: it must stay unmodified so this function can be
    # called on the same lattice more than once.
    tokens = lats.phones

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedTensor with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # token_seqs is a k2.RaggedTensor sharing the same shape as `paths`
    # but it contains token IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][token_id]
    token_seqs = k2.ragged.index(tokens, paths)

    # Remove epsilons (0s) and -1 from token_seqs
    token_seqs = token_seqs.remove_values_leq(0)

    # unique_token_seqs is still a k2.RaggedTensor with axes
    # [seq][path][token_id].
    # But the number of paths in each sequence may be different.
    unique_token_seqs, _, _ = token_seqs.unique(need_num_repeats=False,
                                                need_new2old_indexes=False)

    seq_to_path_shape = unique_token_seqs.shape.get_layer(0)

    # Remove the seq axis.
    # Now unique_token_seqs has only two axes [path][token_id]
    unique_token_seqs = unique_token_seqs.remove_axis(0)

    token_fsas = k2.linear_fsa(unique_token_seqs)

    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
Example #9
0
 def compile_one_and_cache(self, text: str) -> k2.Fsa:
     """Compile one space-separated transcript into a CTC decoding graph.

     Words absent from ``self.words`` are replaced by ``self.oov``. The
     linear word FSA is intersected with ``self.L_inv``, connected,
     inverted and arc-sorted. It is then composed with ``self.ctc_topo``
     and connected again.
     """
     word_ids = [
         self.words[word if word in self.words else self.oov]
         for word in text.split(' ')
     ]
     lexicon_graph = k2.intersect(k2.linear_fsa(word_ids), self.L_inv)
     lexicon_graph = k2.arc_sort(k2.connect(lexicon_graph).invert_())
     return k2.connect(k2.compose(self.ctc_topo, lexicon_graph))
Example #10
0
    def test_random_case1(self):
        """Random single-sequence CTC: k2 loss and gradients match torch.

        T frames of random activations (plus 10 extra frames outside the
        supervision segment) and a random target sequence are scored by
        torch.nn.functional.ctc_loss and by k2's intersect_dense.
        """
        # 1 sequence
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            T = torch.randint(10, 100, (1,)).item()
            C = torch.randint(20, 30, (1,)).item()
            torch_activation = torch.rand((1, T + 10, C),
                                          dtype=torch.float32,
                                          device=device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            # [N, T, C] -> [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation.permute(1, 0, 2), dim=-1)

            input_lengths = torch.tensor([T]).to(device)
            # A CTC alignment of L labels with R adjacent repeats needs at
            # least L + R frames. Keeping L <= T // 2 guarantees
            # L + R <= 2 * L - 1 < T, so the loss is always finite. With L up
            # to T - 1, a random target could be unalignable, giving an
            # infinite loss and NaN torch gradients (zero_infinity=False),
            # which made the gradient check fail intermittently.
            target_lengths = torch.randint(1, T // 2 + 1, (1,)).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.item(),)).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo(list(range(C))).invert_())
            linear_fsa = k2.linear_fsa([targets.tolist()])

            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation.grad,
                                  k2_activation.grad,
                                  atol=1e-2)
Example #11
0
 def test_single_fsa(self):
     """k2.linear_fsa on a flat list builds a single linear FSA.

     Three symbols give 5 states and 4 zero-score arcs; the last arc has
     label -1.
     """
     symbols = [2, 5, 8]
     num_arcs = len(symbols) + 1
     fsa = k2.linear_fsa(symbols)
     assert len(fsa.shape) == 2
     assert fsa.shape[0] == num_arcs + 1, 'There should be 5 states'
     assert torch.allclose(fsa.scores,
                           torch.zeros(num_arcs, dtype=torch.float32))
     expected_arcs = torch.tensor(
         [[0, 1, 2], [1, 2, 5], [2, 3, 8], [3, 4, -1]], dtype=torch.int32)
     # Compare everything except the last field (`scores`).
     assert torch.allclose(fsa.arcs.values()[:, :-1], expected_arcs)
Example #12
0
    def test_case1(self):
        """One frame, one label: k2.ctc_loss must match torch (mean).

        Uniform activations over 5 classes and the single target `a` give
        an expected loss of 1.6094379425049 (= log 5; the target length is
        1, so 'mean' does not rescale it). Gradients must also agree.
        """
        for device in self.devices:
            # Five symbols in total: <blk> plus the four labels a, b, c, d.
            torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                             0.2]).to(device)
            k2_activation = torch_activation.detach().clone()

            # (T, N, C)
            torch_activation = torch_activation.reshape(
                1, 1, -1).requires_grad_(True)

            # (N, T, C)
            k2_activation = k2_activation.reshape(1, 1,
                                                  -1).requires_grad_(True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)

            # we have only one sequence and its label is `a`
            targets = torch.tensor([1]).to(device)
            input_lengths = torch.tensor([1]).to(device)
            target_lengths = torch.tensor([1]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            assert torch.allclose(torch_loss,
                                  torch.tensor([1.6094379425049]).to(device))

            # (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)

            supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo = k2.ctc_topo(4)
            linear_fsa = k2.linear_fsa([1])
            decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)

            torch_loss.backward()
            k2_loss.backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #13
0
    def test_case3(self):
        """Three frames, targets `b, c`: k2.ctc_loss must match torch (mean).

        With reduction='mean' both losses are divided by the target length,
        so the expected value is 4.938850402832 / 2. Gradients with respect
        to the activations must also agree.
        """
        for device in self.devices:
            # Activations laid out as (T, N, C) = (3, 1, 5).
            rows = [
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]
            torch_activation = torch.tensor([rows]).permute(
                1, 0, 2).to(device).requires_grad_(True)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)
            k2_activation = torch_activation.detach().clone()
            k2_activation.requires_grad_(True)

            # The single sequence's labels are `b, c`.
            targets = torch.tensor([2, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch.nn.functional.log_softmax(torch_activation,
                                                          dim=-1),
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            # k2 works on (N, T, C).
            k2_log_probs = torch.nn.functional.log_softmax(
                k2_activation.permute(1, 0, 2), dim=-1)
            segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs, segments).to(device)

            decoding_graph = k2.compose(k2.ctc_topo(4),
                                        k2.linear_fsa([2, 3])).to(device)
            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            expected_loss = torch.tensor([4.938850402832],
                                         device=device) / target_lengths
            assert torch.allclose(torch_loss, k2_loss)
            assert torch.allclose(torch_loss, expected_loss)

            torch_loss.backward()
            k2_loss.backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #14
0
def create_decoding_graph(texts, L, symbols):
    """Build a batched decoding graph from space-separated transcripts.

    Args:
        texts: Transcripts whose words are separated by single spaces.
        L: FSA the linear word FSAs are intersected with (presumably the
            lexicon; confirm against the caller).
        symbols: Word symbol table. Words not in ``symbols._sym2id`` are
            replaced by '<UNK>' before lookup with ``symbols.get``.

    Returns:
        The inverted intersection of the linear word FSAs with ``L``, with
        epsilon self-loops added.
    """
    word_ids_list = [
        [symbols.get(word if word in symbols._sym2id else '<UNK>')
         for word in text.split(' ')]
        for text in texts
    ]
    graph = k2.intersect(k2.linear_fsa(word_ids_list), L).invert_()
    return k2.add_epsilon_self_loops(graph)
 def test_multiple_fsa(self):
     """linear_fsa_with_self_loops on an FsaVec of three linear FSAs.

     For each FSA, the expected labels keep only the non-zero input labels,
     interleaved with 0 and ending in -1. The per-FSA label lists are
     concatenated, and the device is preserved.
     """
     for device in self.devices:
         labels = [
             [2, 0, 0, 0, 5, 0, 0, 0, 8, 0, 0],
             [1, 2],
             [0, 0, 0, 3, 0, 2],
         ]
         src = k2.linear_fsa(labels, device)
         dst = k2.linear_fsa_with_self_loops(src)
         assert dst.device == src.device
         expected_per_fsa = [
             [0, 2, 0, 5, 0, 8, 0, -1],
             [0, 1, 0, 2, 0, -1],
             [0, 3, 0, 2, 0, -1],
         ]
         expected_labels = [x for seq in expected_per_fsa for x in seq]
         assert dst.labels.tolist() == expected_labels
Example #16
0
def create_decoding_graph(texts, graph, symbols):
    """Build one decoding graph per transcript and stack them.

    Args:
        texts: Transcripts whose words are separated by single spaces.
        graph: FSA each arc-sorted linear word FSA is intersected with.
        symbols: Word symbol table. Words not in ``symbols._sym2id`` are
            replaced by '<UNK>' before lookup with ``symbols.get``.

    Returns:
        An FsaVec of the inverted intersections, each with epsilon
        self-loops added.
    """
    def _compile_text(text):
        # Map unknown words to '<UNK>', then look up their IDs.
        words = [w if w in symbols._sym2id else '<UNK>'
                 for w in text.split(' ')]
        word_ids = [symbols.get(w) for w in words]
        linear = k2.arc_sort(k2.linear_fsa(word_ids))
        decoding_graph = k2.intersect(linear, graph).invert_()
        return k2.add_epsilon_self_loops(decoding_graph)

    return k2.create_fsa_vec([_compile_text(text) for text in texts])
Example #17
0
    def test_case2(self):
        """Three frames, repeated targets `c, c`: k2 graph CTC matches torch.

        The negated k2 total score must equal torch's ctc_loss
        (7.355742931366), and the gradients must agree.
        """
        for device in self.devices:
            # Activations 1..15 arranged as (T, N, C) = (3, 1, 5).
            base = torch.arange(1, 16).reshape(1, 3, 5)
            torch_activation = base.permute(1, 0, 2).to(device)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)
            k2_activation = torch_activation.detach().clone()
            k2_activation.requires_grad_(True)

            # The single sequence's labels are `c, c`.
            targets = torch.tensor([3, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch.nn.functional.log_softmax(torch_activation,
                                                          dim=-1),
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            # k2 works on (N, T, C).
            k2_log_probs = torch.nn.functional.log_softmax(
                k2_activation.permute(1, 0, 2), dim=-1)
            segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs, segments).to(device)

            topo_inv = k2.arc_sort(build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            graph = k2.intersect(topo_inv, k2.linear_fsa([3, 3]))
            graph = k2.connect(graph).invert_().to(device)

            lattice = k2.intersect_dense(graph, dense_fsa_vec, 100.0)
            k2_scores = lattice.get_tot_scores(log_semiring=True,
                                               use_double_scores=False)

            expected = torch.tensor([7.355742931366]).to(device)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss, expected)

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #18
0
    def compile(self, targets: torch.Tensor,
                target_lengths: torch.Tensor) -> 'k2.Fsa':
        """Build decoding graphs for a padded batch of targets.

        Args:
            targets: Padded target token IDs; row ``i`` holds the tokens of
                utterance ``i``.
            target_lengths: Number of valid tokens at the start of each row
                of ``targets``.

        Returns:
            Arc-sorted decoding graphs on ``self.device``, obtained by
            composing ``self.base_graph`` with the label graphs via
            ``compose_with_self_loops``, with ``requires_grad`` disabled.
        """
        token_ids_list = [row[:length].tolist()
                          for row, length in zip(targets, target_lengths)]
        # see https://github.com/k2-fsa/k2/issues/835
        label_graph = k2.linear_fsa(token_ids_list).to(self.device)
        # aux_labels mirror labels on the label graph.
        label_graph.aux_labels = label_graph.labels.clone()

        graphs = compose_with_self_loops(self.base_graph, label_graph)
        graphs = k2.arc_sort(graphs).to(self.device)

        # make sure the gradient is not accumulated
        graphs.requires_grad_(False)
        return graphs
Example #19
0
    def test_random_case1(self):
        """Random single-sequence CTC: k2.ctc_loss matches torch (mean).

        T frames of random activations (plus 10 extra frames outside the
        supervision segment) and a random target sequence are scored by
        torch.nn.functional.ctc_loss and k2.ctc_loss. Losses and gradients
        must agree.
        """
        # 1 sequence
        for device in self.devices:
            T = torch.randint(10, 100, (1, )).item()
            C = torch.randint(20, 30, (1, )).item()
            torch_activation = torch.rand((1, T + 10, C),
                                          dtype=torch.float32,
                                          device=device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            # [N, T, C] -> [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation.permute(1, 0, 2), dim=-1)

            input_lengths = torch.tensor([T]).to(device)
            # A CTC alignment of L labels with R adjacent repeats needs at
            # least L + R frames. Keeping L <= T // 2 guarantees
            # L + R <= 2 * L - 1 < T, so the loss is always finite. With L up
            # to T - 1, a random target could be unalignable, giving an
            # infinite loss and NaN torch gradients (zero_infinity=False),
            # which made the gradient check fail intermittently.
            target_lengths = torch.randint(1, T // 2 + 1, (1, )).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.item(), )).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo = k2.ctc_topo(C - 1)
            linear_fsa = k2.linear_fsa([targets.tolist()])
            decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (k2_loss * scale).sum().backward()
            assert torch.allclose(torch_activation.grad,
                                  k2_activation.grad,
                                  atol=1e-2)
Example #20
0
    def test_from_ragged_int_single_fsa(self):
        """k2.linear_fsa accepts a RaggedInt with a single sub-list.

        The result is an FsaVec of one linear FSA on the ragged tensor's
        device, with zero scores and a final -1 arc.
        """
        for device in self.devices:
            src = k2.RaggedInt('[ [10 20] ]').to(device)
            fsa = k2.linear_fsa(src)
            assert fsa.shape == (1, None, None)
            assert fsa.device == device

            expected_arcs = torch.tensor(
                [[0, 1, 10], [1, 2, 20], [2, 3, -1]],
                dtype=torch.int32,
                device=device)
            # Compare every field except the last one (`scores`).
            arcs = fsa.arcs.values()[:, :-1]
            assert torch.all(torch.eq(arcs, expected_arcs))

            zeros = torch.zeros_like(fsa.scores)
            assert torch.all(torch.eq(fsa.scores, zeros))
Example #21
0
def nbest_am_lm_scores(
    lats: k2.Fsa,
    num_paths: int,
    device: str = "cuda",
    batch_size: int = 500,
):
    """Compute am scores with word_seqs

    Compatible with both ctc_decoding or TLG decoding. With ctc_decoding the
    lattice's aux_labels hold token IDs; with TLG decoding they hold word
    IDs.

    Args:
        lats: Lattice to sample paths from.
        num_paths: Number of paths sampled per sequence by
            ``k2.random_paths``.
        device: Forwarded to ``compute_am_scores_and_lm_scores``.
        batch_size: Forwarded to ``compute_am_scores_and_lm_scores``.

    Returns:
        ``(am_scores, lm_scores, token_ids, new2old, path_to_seq_map,
        seq_to_path_splits)``. ``token_ids`` holds, for each unique
        aux-label sequence, its token IDs after ``remove_repeated_and_leq``.
    """
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    aux_labels = lats.aux_labels
    if isinstance(aux_labels, torch.Tensor):
        word_seqs = k2.ragged.index(aux_labels.contiguous(), paths)
    else:
        # '_k2.RaggedInt' object has no attribute 'contiguous'
        word_seqs = aux_labels.index(paths)
        word_seqs = word_seqs.remove_axis(word_seqs.num_axes - 2)

    # Drop 0s and -1s, then deduplicate the sampled sequences.
    word_seqs = word_seqs.remove_values_leq(0)
    unique_word_seqs, _num_repeats, new2old = word_seqs.unique(
        need_num_repeats=True, need_new2old_indexes=True
    )

    seq_to_path_shape = unique_word_seqs.shape.get_layer(0)
    path_to_seq_map = seq_to_path_shape.row_ids(1)
    # used to split final computed tot_scores
    seq_to_path_splits = seq_to_path_shape.row_splits(1)

    word_fsas = k2.linear_fsa(unique_word_seqs.remove_axis(0))
    am_scores, lm_scores = compute_am_scores_and_lm_scores(
        lats,
        k2.add_epsilon_self_loops(word_fsas),
        path_to_seq_map,
        device,
        batch_size,
    )

    token_seqs = k2.ragged.index(lats.labels.contiguous(), paths)
    token_seqs = token_seqs.remove_axis(0)
    selected, _ = token_seqs.index(new2old, axis=0)
    # Now remove repeated tokens and 0s and -1s.
    token_ids = [remove_repeated_and_leq(tokens)
                 for tokens in selected.tolist()]
    return (am_scores, lm_scores, token_ids, new2old, path_to_seq_map,
            seq_to_path_splits)
Example #22
0
    def test_single_fsa(self):
        """k2.linear_fsa(labels, device) builds one zero-score linear FSA.

        Three labels give 5 states and 4 arcs, the last labelled -1, all on
        the requested device.
        """
        for device in self.devices:
            labels = [2, 5, 8]
            fsa = k2.linear_fsa(labels, device)
            assert fsa.device == device
            assert len(fsa.shape) == 2
            assert fsa.shape[0] == len(labels) + 2, 'There should be 5 states'

            zeros = torch.zeros_like(fsa.scores)
            assert torch.all(torch.eq(fsa.scores, zeros))

            expected_arcs = torch.tensor(
                [[0, 1, 2], [1, 2, 5], [2, 3, 8], [3, 4, -1]],
                dtype=torch.int32,
                device=device)
            # Skip the last field (`scores`) of each arc.
            arcs = fsa.arcs.values()[:, :-1]
            assert torch.all(torch.eq(arcs, expected_arcs))
Example #23
0
 def test_single_fsa(self):
     """k2.linear_fsa(labels, device) on CPU and, if available, cuda:0.

     Three labels give 5 states and 4 zero-score arcs, the last one
     labelled -1.
     """
     cuda = [torch.device('cuda', 0)] if torch.cuda.is_available() else []
     devices = [torch.device('cpu')] + cuda
     for device in devices:
         labels = [2, 5, 8]
         fsa = k2.linear_fsa(labels, device)
         assert fsa.device == device
         assert len(fsa.shape) == 2
         assert fsa.shape[0] == len(labels) + 2, 'There should be 5 states'
         # `.to(fsa.scores)` matches the dtype and device of the scores.
         zeros = torch.zeros(len(labels) + 1).to(fsa.scores)
         assert torch.allclose(fsa.scores, zeros)
         expected_arcs = torch.tensor(
             [[0, 1, 2], [1, 2, 5], [2, 3, 8], [3, 4, -1]],
             dtype=torch.int32,
             device=device)
         # Skip the last field (`scores`) of each arc.
         assert torch.all(torch.eq(fsa.arcs.values()[:, :-1], expected_arcs))
Example #24
0
    def test_from_ragged_int_two_fsas(self):
        """k2.linear_fsa on a two-row RaggedTensor builds an FsaVec of two.

        Each row becomes a linear FSA ending in a -1 arc; all scores are 0
        and the result lives on the ragged tensor's device.
        """
        for device in self.devices:
            src = k2.RaggedTensor([[10, 20], [100, 200, 300]]).to(device)
            fsa = k2.linear_fsa(src)
            assert fsa.shape == (2, None, None)
            assert fsa.device == device

            expected_rows = [
                [0, 1, 10], [1, 2, 20], [2, 3, -1],  # FSA 0
                [0, 1, 100], [1, 2, 200], [2, 3, 300], [3, 4, -1],  # FSA 1
            ]
            expected_arcs = torch.tensor(expected_rows,
                                         dtype=torch.int32,
                                         device=device)
            # Skip the last field (`scores`) of each arc.
            arcs = fsa.arcs.values()[:, :-1]
            assert torch.all(torch.eq(arcs, expected_arcs))

            zeros = torch.zeros_like(fsa.scores)
            assert torch.all(torch.eq(fsa.scores, zeros))
Example #25
0
    def compile_one_and_cache(self, text: str) -> k2.Fsa:
        '''Compile a transcript into an FST via the lexicon and word table.

        Args:
          text:
            The transcript, with words separated by spaces. Words not found
            in ``self.words`` are replaced by ``self.oov``.

        Returns:
          An arc-sorted FST for the transcript (the inverted, connected
          intersection of its linear word FSA with ``self.L_inv``). Its
          `labels` are phone IDs and `aux_labels` are word IDs.
        '''
        word_ids = [
            self.words[word if word in self.words else self.oov]
            for word in text.split(' ')
        ]
        num_graph = k2.intersect(k2.linear_fsa(word_ids), self.L_inv)
        return k2.arc_sort(k2.connect(num_graph).invert_())
Example #26
0
def create_decoding_graph(texts, L, symbols):
    """Build an FsaVec containing one decoding graph per transcript.

    For each transcript, tokens (split on single spaces) that are not in
    ``symbols._sym2id`` are replaced by ``'<UNK>'``. The resulting word IDs
    form a linear FSA, which is arc-sorted, intersected with ``L`` and
    inverted. The result then gets epsilon self-loops added.

    Args:
      texts:
        Iterable of transcripts whose words are separated by spaces.
      L:
        FSA passed as the second operand of ``k2.intersect``.
        NOTE(review): because the word FSA is the first operand and the
        result is inverted, ``L`` presumably has word labels and phone
        aux_labels. Confirm with callers.
      symbols:
        Symbol table exposing ``_sym2id`` and ``get``. It is presumably
        expected to contain ``'<UNK>'``. Confirm.

    Returns:
      The FsaVec produced by ``k2.create_fsa_vec`` over all decoding graphs.
    """
    import logging

    logger = logging.getLogger(__name__)
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.arc_sort(k2.linear_fsa(word_ids))
        # Diagnostics go to a debug-level logger rather than print():
        # callers get a clean stdout, and the (potentially large) FSA
        # reprs are only built when DEBUG logging is enabled.
        logger.debug("linear fsa, arc-sorted, is %s", fsa)
        decoding_graph = k2.intersect(fsa, L).invert_()
        logger.debug("decoding graph is %s", decoding_graph)
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        logger.debug("decoding graph with self-loops is %s", decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
Example #27
0
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical (phone level) transcripts from word level transcripts.

    Each word sequence becomes a linear FSA (on ``lexicon``'s device) with
    epsilon self-loops. These FSAs are intersected with ``lexicon``, with
    epsilons treated as ordinary symbols. The result is epsilon-removed,
    top-sorted and reduced to its shortest path. ``get_texts`` then
    extracts the per-utterance IDs from those best paths.

    Args:
        ys: Word level transcripts, one list of word IDs per utterance.
        lexicon: Its labels are words, while its aux_labels are phones.
            If ``None``, ``ys`` is returned unchanged. In that case the
            elements are the original lists, not tensors.

    Returns:
        List[Tensor]: Phone level transcripts, one per input utterance.
        An empty ``ys`` yields an empty list.

    """
    if lexicon is None:
        return ys
    if not ys:
        # Nothing to align; avoid building and intersecting an empty FsaVec.
        return []

    L_inv = lexicon
    device = L_inv.device
    indices = torch.arange(len(ys))

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which is the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    phone_ys = get_texts(transcripts_lexicon, indices)
    return [torch.tensor(y) for y in phone_ys]
Example #28
0
    def test_from_ragged_int_two_fsas(self):
        """Build an FsaVec of two linear FSAs from a ``k2.RaggedInt``.

        Runs on the CPU and, when CUDA is available, on ``cuda:0``. It checks
        the FsaVec's shape and device, its (src, dst, label) arc columns, and
        that all scores are zero.
        """
        candidate_devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            candidate_devices.append(torch.device('cuda', 0))

        for dev in candidate_devices:
            ragged_labels = k2.RaggedInt('[ [10 20] [100 200 300] ]').to(dev)
            fsa_vec = k2.linear_fsa(ragged_labels)
            assert fsa_vec.shape == (2, None, None)
            assert fsa_vec.device == dev

            # Chain of label arcs per FSA, each ending in a -1 final arc.
            want = torch.tensor(
                [
                    [0, 1, 10], [1, 2, 20], [2, 3, -1],
                    [0, 1, 100], [1, 2, 200], [2, 3, 300], [3, 4, -1],
                ],
                dtype=torch.int32,
                device=dev,
            )
            got = fsa_vec.arcs.values()[:, :-1]  # drop the trailing `scores` field
            assert torch.eq(got, want).all()

            assert (fsa_vec.scores == 0).all()
Example #29
0
 def test_fsa_vec(self):
     """Build an FsaVec of three linear FSAs from a list of label lists.

     Runs on the CPU and, when CUDA is available, on ``cuda:0``. It checks
     the number of FSAs, the device, the (src, dst, label) arc columns and
     that every score is zero.
     """
     candidate_devices = [torch.device('cpu')]
     if torch.cuda.is_available():
         candidate_devices.append(torch.device('cuda', 0))

     labels = [
         [1, 3, 5],
         [2, 6],
         [8, 7, 9],
     ]
     # One arc per label plus one final arc per FSA.
     total_arcs = sum(len(seq) + 1 for seq in labels)

     expected_arcs = []
     for seq in labels:
         # Arc i -> i+1 carries seq[i]; the last arc (label -1) enters the final state.
         expected_arcs.extend([i, i + 1, lab] for i, lab in enumerate(seq))
         expected_arcs.append([len(seq), len(seq) + 1, -1])

     for dev in candidate_devices:
         fsa_vec = k2.linear_fsa(labels, dev)
         assert len(fsa_vec.shape) == 3
         assert fsa_vec.device == dev
         assert fsa_vec.shape[0] == 3, 'There should be 3 FSAs'

         want = torch.tensor(expected_arcs, dtype=torch.int32, device=dev)
         got = fsa_vec.arcs.values()[:, :-1]  # drop the trailing `scores` field
         assert torch.eq(got, want).all()

         zero_scores = torch.zeros(total_arcs).to(fsa_vec.scores)
         assert torch.allclose(fsa_vec.scores, zero_scores)
Example #30
0
    def test_fsa_vec(self):
        """Build an FsaVec of three linear FSAs on every device in ``self.devices``.

        It checks the FsaVec's rank, device and FSA count, its (src, dst,
        label) arc columns, and that all scores are zero.
        """
        labels = [
            [1, 3, 5],
            [2, 6],
            [8, 7, 9],
        ]
        # Expected arcs grouped by FSA; each group ends with the -1 final arc.
        per_fsa_arcs = [
            [[0, 1, 1], [1, 2, 3], [2, 3, 5], [3, 4, -1]],
            [[0, 1, 2], [1, 2, 6], [2, 3, -1]],
            [[0, 1, 8], [1, 2, 7], [2, 3, 9], [3, 4, -1]],
        ]
        flat_arcs = [arc for group in per_fsa_arcs for arc in group]

        for dev in self.devices:
            fsa_vec = k2.linear_fsa(labels, dev)
            assert len(fsa_vec.shape) == 3
            assert fsa_vec.device == dev
            assert fsa_vec.shape[0] == 3, 'There should be 3 FSAs'

            want = torch.tensor(flat_arcs, dtype=torch.int32, device=dev)
            got = fsa_vec.arcs.values()[:, :-1]  # drop the trailing `scores` field
            assert torch.eq(got, want).all()

            assert (fsa_vec.scores == 0).all()