Example #1
    def __init__(self,
                 L_inv: k2.Fsa,
                 phones: k2.SymbolTable,
                 words: k2.SymbolTable,
                 oov: str = '<UNK>'):
        '''
        Args:
          L_inv:
            Its labels are words, while its aux_labels are phones.
          phones:
            The phone symbol table.
          words:
            The word symbol table.
          oov:
            Out of vocabulary word.
        '''
        # Sort only if the ARC_SORTED property bit is not already set.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert oov in words

        self.L_inv = L_inv
        self.phones = phones
        self.words = words
        self.oov = oov
        phone_ids = get_phone_symbols(phones)
        phone_ids_with_blank = [0] + phone_ids
        self.ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
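The guard above is easy to get wrong: k2 caches structural properties of an FSA as a bit mask, and an FSA is arc-sorted exactly when the ARC_SORTED bit is set, so sorting should happen when the bit is clear. A minimal sketch of the check, assuming a tiny hand-built acceptor:

import k2

fsa = k2.Fsa.from_str('''
0 1 2 0.5
0 1 1 0.5
1 2 -1 0.0
2
''')
# The arcs leaving state 0 have labels 2, 1: not sorted, so the bit is clear.
assert fsa.properties & k2.fsa_properties.ARC_SORTED == 0
fsa = k2.arc_sort(fsa)
# After sorting, the ARC_SORTED bit is set.
assert fsa.properties & k2.fsa_properties.ARC_SORTED != 0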
Example #2
    def __init__(self,
                 L_inv: k2.Fsa,
                 phones: k2.SymbolTable,
                 words: k2.SymbolTable,
                 oov: str = '<UNK>'):
        '''
        Args:
          L_inv:
            Its labels are words, while its aux_labels are phones.
          phones:
            The phone symbol table.
          words:
            The word symbol table.
          oov:
            Out of vocabulary word.
        '''
        # Sort only if the ARC_SORTED property bit is not already set.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert oov in words

        self.L_inv = L_inv
        self.phones = phones
        self.words = words
        self.oov = oov
        ctc_topo = build_ctc_topo(list(phones._id2sym.keys()))
        self.ctc_topo = k2.arc_sort(ctc_topo)
Example #3
def intersect_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2.Fsa':
    """Intersect `base_graph` with `aux_graph` after adding epsilon
    self-loops to `aux_graph`, so that epsilon arcs in `base_graph`
    can be matched. The labels of the result are also exposed as
    `result.phones`.
    """
    assert hasattr(base_graph, "aux_labels")
    assert not hasattr(aux_graph, "aux_labels")
    aux_graph_with_self_loops = k2.arc_sort(
        k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device)
    result = k2.intersect(k2.arc_sort(base_graph),
                          aux_graph_with_self_loops,
                          treat_epsilons_specially=False)
    setattr(result, "phones", result.labels)
    return result
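A minimal usage sketch (the FSAs below are made up for illustration): `base` is a tiny transducer, so it carries `aux_labels`, while `aux` is a plain linear acceptor; the added self-loops let `aux` match any epsilon arcs in `base`.

base = k2.Fsa.from_str('''
0 0 1 1 0.0
0 1 -1 -1 0.0
1
''', num_aux_labels=1)
aux = k2.linear_fsa([1, 1])
result = intersect_with_self_loops(base, aux)
# The helper exposes the labels of the result as `result.phones`.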
Example #4
    def compile(self,
                texts: Iterable[str],
                P: k2.Fsa,
                replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
          replicate_den:
            If True, the returned den_graph is replicated to match the number
            of FSAs in the returned num_graph; if False, the returned den_graph
            contains only a single FSA.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph` if replicate_den is True; otherwise, it
              is an FsaVec containing only a single FSA.
        '''
        assert P.device == self.device
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()

        ctc_topo_P = k2.arc_sort(ctc_topo_P)

        num_graphs = self.build_num_graphs(texts)
        num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            num_graphs)

        num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

        num = k2.compose(ctc_topo_P,
                         num_graphs_with_self_loops,
                         treat_epsilons_specially=False)
        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
        if replicate_den:
            indexes = torch.zeros(len(texts),
                                  dtype=torch.int32,
                                  device=self.device)
            den = k2.index_fsa(ctc_topo_P_vec, indexes)
        else:
            den = ctc_topo_P_vec

        return num, den
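A hypothetical call (a sketch; `compiler` and `P` stand for objects built elsewhere), illustrating the shapes promised by the docstring:

texts = ['HELLO WORLD', 'FOO BAR']
num, den = compiler.compile(texts, P, replicate_den=True)
assert num.shape[0] == len(texts)
assert den.shape[0] == len(texts)  # replicated to match num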
Example #5
    def test_empty_fsa(self):
        array_size = k2.IntArray2Size(0, 0)
        fsa = k2.Fsa.create_fsa_with_size(array_size)
        arc_map = k2.IntArray1.create_array_with_size(fsa.size2)
        k2.arc_sort(fsa, arc_map)
        self.assertTrue(k2.is_empty(fsa))
        self.assertTrue(arc_map.empty())

        # test without arc_map
        k2.arc_sort(fsa)
        self.assertTrue(k2.is_empty(fsa))
Example #6
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
        Returns:
          A tuple (num_graph, den_graph, decoding_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.
              It is the result of compose(ctc_topo, P, L, transcript)

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph`.
              It is the result of compose(ctc_topo, P).

            - `decoding_graph` is the result of compose(ctc_topo, L_disambig, G).
              Note that it is a single Fsa, not an FsaVec.
        '''
        assert P.device == self.device
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()
        ctc_topo_P = k2.arc_sort(ctc_topo_P)

        num_graphs = self.build_num_graphs(texts)

        num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            num_graphs)

        num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

        num = k2.compose(ctc_topo_P,
                         num_graphs_with_self_loops,
                         treat_epsilons_specially=False,
                         inner_labels='phones')
        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        den = k2.index_fsa(ctc_topo_P_vec, indexes)

        return num, den, self.decoding_graph
Example #7
    def test_treat_epsilon_specially_true(self):
        # this version works only on CPU and requires
        # arc-sorted inputs
        # a_fsa recognizes `(1|3)?2*`
        s1 = '''
            0 1 3 0.0
            0 1 1 0.2
            0 1 0 0.1
            1 1 2 0.3
            1 2 -1 0.4
            2
        '''
        a_fsa = k2.Fsa.from_str(s1)
        a_fsa.requires_grad_(True)

        # b_fsa recognizes `1|2|5`
        s2 = '''
            0 1 5 0
            0 1 1 1
            0 1 2 2
            1 2 -1 3
            2
        '''
        b_fsa = k2.Fsa.from_str(s2)
        b_fsa.requires_grad_(True)

        # fsa recognizes 1|2
        fsa = k2.intersect(k2.arc_sort(a_fsa), k2.arc_sort(b_fsa))
        assert len(fsa.shape) == 2
        actual_str = k2.to_str_simple(fsa)
        expected_str = '\n'.join(
            ['0 1 0 0.1', '0 2 1 1.2', '1 2 2 2.3', '2 3 -1 3.4', '3'])
        assert actual_str.strip() == expected_str

        loss = fsa.scores.sum()
        (-loss).backward()
        # arc 1, 2, 3, and 4 of a_fsa are kept in the final intersected FSA
        assert torch.allclose(a_fsa.grad,
                              torch.tensor([0, -1, -1, -1, -1]).to(a_fsa.grad))

        # arc 1, 2, and 3 of b_fsa are kept in the final intersected FSA
        assert torch.allclose(b_fsa.grad,
                              torch.tensor([0, -1, -1, -1]).to(b_fsa.grad))

        # if any of the input FSAs is an FsaVec,
        # the output FSA is also an FsaVec.
        a_fsa.scores.grad = None
        b_fsa.scores.grad = None
        a_fsa = k2.create_fsa_vec([a_fsa])
        fsa = k2.intersect(k2.arc_sort(a_fsa), k2.arc_sort(b_fsa))
        assert len(fsa.shape) == 3
Example #8
    def __init__(self,
                 lexicon: Lexicon,
                 P: k2.Fsa,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          lexicon:
            It contains an FST `L_inv`, whose labels are words and whose
            aux_labels are phones, as well as the phone and word symbol tables.
          P:
            A phone bigram LM if the pronunciations in the lexicon are in phones;
            a word piece bigram if the pronunciations in the lexicon are word pieces.
          device:
            The device on which the graphs are constructed.
          oov:
            Out of vocabulary word.
        '''
        self.lexicon = lexicon
        L_inv = self.lexicon.L_inv.to(device)
        P = P.to(device)

        # Sort only if the ARC_SORTED property bit is not already set.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert L_inv.requires_grad is False

        assert oov in self.lexicon.words

        self.L_inv = L_inv
        self.oov_id = self.lexicon.words[oov]
        self.oov = oov
        self.device = device

        phone_symbols = get_phone_symbols(self.lexicon.phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
        assert ctc_topo.requires_grad is False

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
Example #9
    def __init__(
            self,
            num_classes: int,
            topo_type: str = "default",
            topo_with_self_loops: bool = True,
            device: torch.device = torch.device("cpu"),
    ):
        # use k2 import guard
        k2_import_guard()

        self.topo_type = topo_type
        self.device = device
        self.base_graph = k2.arc_sort(
            build_topo(topo_type, list(range(num_classes)),
                       topo_with_self_loops)).to(self.device)
        self.ctc_topo_inv = k2.arc_sort(self.base_graph.invert())
Example #10
def build_ctc_topo2(phones: List[int]):
    # See https://github.com/k2-fsa/k2/issues/746#issuecomment-856421616
    assert 0 in phones, 'We assume 0 is the ID of the blank symbol'
    phones = phones.copy()
    phones.remove(0)

    num_phones = len(phones)

    start = 0
    final = num_phones + 1

    arcs = []
    arcs.append([start, start, 0, 0, 0])
    arcs.append([start, final, -1, -1, 0])
    arcs.append([final])
    for i, p in enumerate(phones):
        i += 1
        arcs.append([start, start, p, p, 0])

        arcs.append([start, i, p, p, 0])
        arcs.append([i, i, p, 0, 0])

        arcs.append([i, start, p, 0, 0])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [' '.join(arc) for arc in arcs]
    arcs = '\n'.join(arcs)
    ctc_topo = k2.Fsa.from_str(arcs, False)
    return k2.arc_sort(ctc_topo)
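A quick sanity check (a sketch): for three symbols with 0 as the blank, the topology has one start state, one state per non-blank phone, and one final state.

topo = build_ctc_topo2([0, 1, 2])
assert topo.shape[0] == 4  # start + two phone states + final
print(k2.to_str_simple(topo))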
Example #11
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.

    """

    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
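A hypothetical call (a sketch; `ys` holds word-ID sequences and `L_inv` is the lexicon FST described in the docstring):

phone_ys = get_hierarchical_targets(ys, L_inv)
assert len(phone_ys) == len(ys)  # one phone-level tensor per utterance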
Example #12
    def test_case1(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # suppose we have four symbols: <blk>, a, b, c, d
            torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                             0.2]).to(device)
            k2_activation = torch_activation.detach().clone()

            # (T, N, C)
            torch_activation = torch_activation.reshape(
                1, 1, -1).requires_grad_(True)

            # (N, T, C)
            k2_activation = k2_activation.reshape(1, 1,
                                                  -1).requires_grad_(True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)

            # we have only one sequence and its label is `a`
            targets = torch.tensor([1]).to(device)
            input_lengths = torch.tensor([1]).to(device)
            target_lengths = torch.tensor([1]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert torch.allclose(torch_loss,
                                  torch.tensor([1.6094379425049]).to(device))

            # (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)

            supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([1])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #13
    def test1(self):
        s = '''
            0 4 1 1
            0 1 1 1
            1 2 2 2
            1 3 3 3
            2 7 1 4
            3 7 1 5
            4 6 1 2
            4 6 1 3
            4 5 1 3
            4 8 -1 2
            5 8 -1 4
            6 8 -1 3
            7 8 -1 5
            8
        '''
        fsa = k2.Fsa.from_str(s)
        prop = fsa.properties
        self.assertFalse(
            prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0)
        dest = k2.determinize(fsa)
        log_semiring = False
        self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))
        arc_sorted = k2.arc_sort(dest)
        prop = arc_sorted.properties
        self.assertTrue(
            prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0)
Example #14
def build_ctc_topo(tokens: List[int]) -> k2.Fsa:
    '''Build CTC topology.

    The resulting topology converts repeated input
    symbols to a single output symbol.

    Caution:
      The resulting topo is an FST. Epsilons are on the left
      side (i.e., ilabels) and tokens are on the right side (i.e., olabels).

    Args:
      tokens:
        A list of tokens, e.g., phones, characters, etc.
    Returns:
      Returns an FST that converts repeated tokens to a single token.
    '''
    assert 0 in tokens, 'We assume 0 is the ID of the blank symbol'

    num_states = len(tokens)
    final_state = num_states
    rules = ''
    for i in range(num_states):
        for j in range(num_states):
            if i == j:
                rules += f'{i} {i} 0 {tokens[i]} 0.0\n'
            else:
                rules += f'{i} {j} {tokens[j]} {tokens[j]} 0.0\n'
        rules += f'{i} {final_state} -1 -1 0.0\n'
    rules += f'{final_state}'
    ans = k2.Fsa.from_str(rules)
    ans = k2.arc_sort(ans)
    return ans
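A small sketch of the result, assuming the k2 version this snippet targets (where `from_str` infers the transducer format from the arc lines): the topology has one state per token plus a final state, fully connected over the tokens.

topo = build_ctc_topo([0, 1, 2])
assert topo.shape[0] == len([0, 1, 2]) + 1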
Example #15
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph`.
        '''
        assert P.is_cpu()

        ctc_topo_P = k2.intersect(self.ctc_topo, P).invert_()
        ctc_topo_P = k2.connect(ctc_topo_P)

        num_graphs = k2.create_fsa_vec(
            [self.compile_one_and_cache(text) for text in texts])

        num = k2.compose(ctc_topo_P, num_graphs)
        num = k2.connect(num)
        num = k2.arc_sort(num)

        den = k2.create_fsa_vec([ctc_topo_P.detach()] * len(texts))

        return num, den
Example #16
def build_ctc_topo(tokens: List[int]) -> k2.Fsa:
    """Build CTC topology.

    A token which appears once on the right side (i.e. olabels) may
    appear multiple times on the left side (ilabels), possibly with
    epsilons in between.
    When 0 appears on the left side, it represents the blank symbol;
    when it appears on the right side, it indicates an epsilon. That
    is, 0 has two meanings here.

    Args:
      tokens:
        A list of tokens, e.g., phones, characters, etc.
    Returns:
      Returns an FST that converts repeated tokens to a single token.
    """
    assert 0 in tokens, "We assume 0 is the ID of the blank symbol"

    num_states = len(tokens)
    final_state = num_states
    arcs = ""
    for i in range(num_states):
        for j in range(num_states):
            if i == j:
                arcs += f"{i} {i} {tokens[i]} 0 0.0\n"
            else:
                arcs += f"{i} {j} {tokens[j]} {tokens[j]} 0.0\n"
        arcs += f"{i} {final_state} -1 -1 0.0\n"
    arcs += f"{final_state}"
    ans = k2.Fsa.from_str(arcs, num_aux_labels=1)
    ans = k2.arc_sort(ans)
    return ans
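In this variant the self-loop keeps the token on the input side (ilabel) and emits epsilon, i.e. the mirror image of the variant above. A quick sketch:

topo = build_ctc_topo([0, 1, 2])
assert topo.shape[0] == 4  # len(tokens) states plus a final state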
Example #17
def build_shared_blank_topo(tokens: List[int], with_self_loops: bool = True) -> 'k2.Fsa':
    """Build the shared blank CTC topology.
    See https://github.com/k2-fsa/k2/issues/746#issuecomment-856421616
    """
    assert 0 in tokens, "We assume 0 is the ID of the blank symbol"

    tokens = tokens.copy()
    tokens.remove(0)
    num_tokens = len(tokens)
    start = 0
    final = num_tokens + 1
    arcs = []
    arcs.append([start, start, 0, 0, 0])
    arcs.append([start, final, -1, -1, 0])
    arcs.append([final])
    for i, p in enumerate(tokens):
        i += 1
        arcs.append([start, start, p, p, 0])
        arcs.append([start, i, p, p, 0])
        arcs.append([i, start, p, 0, 0])
        if with_self_loops:
            arcs.append([i, i, p, 0, 0])
    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)
    ans = k2.Fsa.from_str(arcs, num_aux_labels=1)
    ans = k2.arc_sort(ans)
    return ans
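A usage sketch: the optional self-loop on each token state is what lets a token span arbitrarily many frames; without it, a non-blank token can cover at most two frames on the paths through its dedicated state.

topo = build_shared_blank_topo([0, 1, 2])
topo_strict = build_shared_blank_topo([0, 1, 2], with_self_loops=False)
assert topo.num_arcs == topo_strict.num_arcs + 2  # one self-loop per non-blank token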
Example #18
    def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
        '''Convert transcripts to an FsaVec with the help of a lexicon
        and a word symbol table.

        Args:
          texts:
            Each element is a transcript containing words separated by spaces.
            For instance, it may be 'HELLO SNOWFALL', which contains
            two words.

        Returns:
          Return an FST (FsaVec) corresponding to the transcripts. Its `labels` are
          phone IDs and `aux_labels` are word IDs.
        '''
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split(' '):
                if word in self.words:
                    word_ids.append(self.words[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        fsa = k2.linear_fsa(word_ids_list, self.device)
        fsa = k2.add_epsilon_self_loops(fsa)
        num_graphs = k2.intersect(self.L_inv,
                                  fsa,
                                  treat_epsilons_specially=False).invert_()
        num_graphs = k2.arc_sort(num_graphs)
        return num_graphs
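A hypothetical call (a sketch; `graph_compiler` stands for an instance of the class this method belongs to):

fsa_vec = graph_compiler.build_num_graphs(['HELLO SNOWFALL'])
assert fsa_vec.shape[0] == 1  # one FSA per transcript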
Example #19
    def __init__(
        self,
        asr_train_config: Union[Path, str],
        asr_model_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 8,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        output_beam_size: int = 8,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, device)
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        token_list = asr_model.token_list
        self.decode_graph = k2.arc_sort(
            build_ctc_topo(list(range(len(token_list))))).to(device)

        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type,
                                            bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        logging.info(f"Running on : {device}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.output_beam_size = output_beam_size
Example #20
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa ``G``.
    Involves arc sorting, intersection, determinization, removal of disambiguation symbols
    and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
                and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    :return:
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
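Hypothetical usage (a sketch): `L` and `G` are assumed to have been loaded elsewhere, e.g. from OpenFST text files via `k2.Fsa.from_openfst`, and the two `*_disambig_id_start` arguments are the IDs of the first disambiguation symbol (typically `#0`) in the phone and word symbol tables.

LG = compile_LG(L, G,
                labels_disambig_id_start=first_phone_disambig_id,
                aux_labels_disambig_id_start=first_word_disambig_id)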
Example #21
    def test_case3(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # (T, N, C)
            torch_activation = torch.tensor([[
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]]).permute(1, 0, 2).to(device, torch.float32).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `b,c`
            targets = torch.tensor([2, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([2, 3])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss,
                                  torch.tensor([4.938850402832]).to(device))

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #22
    def compile_one_and_cache(self, text: str) -> k2.Fsa:
        tokens = (token if token in self.words else self.oov
                  for token in text.split(' '))
        word_ids = [self.words[token] for token in tokens]
        fsa = k2.linear_fsa(word_ids)
        decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
        decoding_graph = k2.arc_sort(decoding_graph)
        decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
        decoding_graph = k2.connect(decoding_graph)
        return decoding_graph
Example #23
    def test_random_case1(self):
        # 1 sequence
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            T = torch.randint(10, 100, (1,)).item()
            C = torch.randint(20, 30, (1,)).item()
            torch_activation = torch.rand((1, T + 10, C),
                                          dtype=torch.float32,
                                          device=device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            # [N, T, C] -> [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation.permute(1, 0, 2), dim=-1)

            input_lengths = torch.tensor([T]).to(device)
            target_lengths = torch.randint(1, T, (1,)).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.item(),)).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo(list(range(C))).invert_())
            linear_fsa = k2.linear_fsa([targets.tolist()])

            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation.grad,
                                  k2_activation.grad,
                                  atol=1e-2)
Example #24
    def test_composition_equivalence(self):
        index = _generate_fsa_vec()
        index = k2.arc_sort(k2.connect(k2.remove_epsilon(index)))

        src = _generate_fsa_vec()

        replace = k2.replace_fsa(src, index, 1)
        replace = k2.top_sort(replace)

        f_fsa = _construct_f(src)
        f_fsa = k2.arc_sort(f_fsa)
        intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
        intersect = k2.invert(intersect)
        intersect = k2.top_sort(intersect)
        delattr(intersect, 'aux_labels')

        assert k2.is_rand_equivalent(replace,
                                     intersect,
                                     log_semiring=True,
                                     delta=1e-3)
Example #25
    def test_arc_sort(self):
        s = r'''
        0 1 2
        0 4 0
        0 2 0
        1 2 1
        1 3 0
        2 1 0
        4
        '''

        fsa = k2.str_to_fsa(s)
        arc_map = k2.IntArray1.create_array_with_size(fsa.size2)
        k2.arc_sort(fsa, arc_map)
        expected_arc_indexes = torch.IntTensor([0, 3, 5, 6, 6, 6])
        expected_arcs = torch.IntTensor([[0, 2, 0], [0, 4, 0], [0, 1, 2],
                                         [1, 3, 0], [1, 2, 1], [2, 1, 0]])
        expected_arc_map = torch.IntTensor([2, 1, 0, 4, 3, 5])
        self.assertTrue(torch.equal(fsa.indexes, expected_arc_indexes))
        self.assertTrue(torch.equal(fsa.data, expected_arcs))
        self.assertTrue(torch.equal(arc_map.data, expected_arc_map))
Example #26
def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
                      path_to_seq_map: torch.Tensor) -> torch.Tensor:
    '''Compute AM scores of n-best lists (represented as word_fsas).

    Args:
      lats:
        An FsaVec, which is the output of `k2.intersect_dense_pruned`.
        It must have the attribute `lm_scores`.
      word_fsas_with_epsilon_loops:
        An FsaVec representing an n-best list. Note that it has been processed
        by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
        which sequence the i-th Fsa in word_fsas_with_epsilon_loops belongs to.
        path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
    Returns:
      Return a 1-D torch.Tensor containing the AM scores of each path.
      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
    '''
    device = lats.device
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lats`, we use k2.intersect_device here.
    #
    # lats has phone IDs as `labels` and word IDs as aux_labels, so we
    # need to invert it here.
    inverted_lats = k2.invert(lats)

    # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are phone IDs (a k2.RaggedInt with 2 axes).

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    del inverted_lats.aux_labels
    inverted_lats = k2.arc_sort(inverted_lats)

    am_path_lats = _intersect_device(inverted_lats,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=path_to_seq_map,
                                     sorted_match_a=True)

    # NOTE: `k2.connect` and `k2.top_sort` support only CPU at present
    am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu'))).to(device)

    # The `scores` of every arc consists of `am_scores` and `lm_scores`
    am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores

    am_scores = am_path_lats.get_tot_scores(True, True)

    return am_scores
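Hypothetical usage (a sketch): `lats` comes from `k2.intersect_dense_pruned` and carries `lm_scores`; `word_fsas` is an n-best list that has already been passed through `k2.add_epsilon_self_loops`; here every path is assumed to belong to sequence 0.

path_to_seq_map = torch.zeros(word_fsas.shape[0], dtype=torch.int32)
am_scores = compute_am_scores(lats, word_fsas, path_to_seq_map)
assert am_scores.numel() == word_fsas.shape[0]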
Example #27
def create_decoding_graph(texts, graph, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        fsa = k2.arc_sort(fsa)
        decoding_graph = k2.intersect(fsa, graph).invert_()
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
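A hypothetical call (a sketch; `graph` would typically be an arc-sorted lexicon-style FST with word labels, and `symbols` its word symbol table):

graphs = create_decoding_graph(['HELLO WORLD'], graph, symbols)
assert graphs.shape[0] == 1  # one decoding graph per transcript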
Example #28
    def test_case2(self):
        for device in self.devices:
            # (T, N, C)
            torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
                1, 0, 2).to(device)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `c,c`
            targets = torch.tensor([3, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([3, 3])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss,
                                  torch.tensor([7.355742931366]).to(device))

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #29
    def compile(self, targets: torch.Tensor,
                target_lengths: torch.Tensor) -> 'k2.Fsa':
        token_ids_list = [
            t[:l].tolist() for t, l in zip(targets, target_lengths)
        ]
        # see https://github.com/k2-fsa/k2/issues/835
        label_graph = k2.linear_fsa(token_ids_list).to(self.device)
        label_graph.aux_labels = label_graph.labels.clone()
        decoding_graphs = compose_with_self_loops(self.base_graph, label_graph)
        decoding_graphs = k2.arc_sort(decoding_graphs).to(self.device)

        # make sure the gradient is not accumulated
        decoding_graphs.requires_grad_(False)
        return decoding_graphs
Example #30
    def __init__(self,
                 lexicon: Lexicon,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          lexicon:
            It contains an FST `L_inv`, whose labels are words and whose
            aux_labels are phones, as well as the phone and word symbol tables.
          device:
            The device on which the graphs are constructed.
          oov:
            Out of vocabulary word.
        '''
        self.lexicon = lexicon
        L_inv = self.lexicon.L_inv.to(device)

        # Sort only if the ARC_SORTED property bit is not already set.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert L_inv.requires_grad is False

        assert oov in self.lexicon.words

        self.L_inv = L_inv
        self.oov_id = self.lexicon.words[oov]
        self.oov = oov
        self.device = device

        phone_symbols = get_phone_symbols(self.lexicon.phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
        assert ctc_topo.requires_grad is False

        self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())