Example #1
 def test_aux_ragged(self):
     for device in self.devices:
         s = '''
             0 1 1 0.1
             0 2 2 0.2
             1 3 3 0.3
             2 3 4 0.6
             3 4 -1 0.7
             4
         '''
         # https://git.io/JqNiR
         fsa = k2.Fsa.from_str(s).to(device)
         fsa.aux_labels = k2.RaggedInt('[[2 3] [3 4] [] [5] [-1]]').to(
             device)
         fsa.tensor_attr1 = torch.tensor([1, 2, 3, 4, 5]).to(device)
         # https://git.io/JqNiw
         ans = k2.invert(fsa)
         assert torch.all(
             torch.eq(ans.tensor_attr1,
                      torch.tensor([1, 2, 0, 0, 3, 4, 5], device=device)))
         assert torch.all(
             torch.eq(ans.aux_labels,
                      torch.tensor([1, 2, 0, 0, 3, 4, -1], device=device)))
         assert torch.all(
             torch.eq(ans.labels,
                      torch.tensor([2, 3, 3, 4, 0, 5, -1], device=device)))
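The assertions above follow from how k2.invert handles ragged aux_labels:
labels and aux_labels swap roles, and a sublist with k elements expands
into a chain of k arcs, with epsilons (label 0) padding whichever side
runs out of symbols. A minimal sketch of the same behavior, assuming the
k2.RaggedInt API used above:

import k2

fsa = k2.Fsa.from_str('''
    0 1 1 0.5
    1 2 -1 0.5
    2
''')
fsa.aux_labels = k2.RaggedInt('[[10 11] [-1]]')
inv = k2.invert(fsa)
# inv.labels is [10, 11, -1] and inv.aux_labels is [1, 0, -1]: the
# sublist [10 11] expanded into two arcs, with an epsilon filling the
# second slot on the aux side.
print(inv.labels, inv.aux_labels)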
Example #2
def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
                      path_to_seq_map: torch.Tensor) -> torch.Tensor:
    '''Compute AM scores of n-best lists (represented as word_fsas).

    Args:
      lats:
        An FsaVec, which is the output of `k2.intersect_dense_pruned`.
        It must have the attribute `lm_scores`.
      word_fsas_with_epsilon_loops:
        An FsaVec representing an n-best list. Note that it has been processed
        by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
        which sequence the i-th Fsa in word_fsas_with_epsilon_loops belongs to.
        path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
    Returns:
      Return a 1-D torch.Tensor containing the AM scores of each path.
      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
    '''
    device = lats.device
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lats`, we use k2.intersect_device here.
    #
    # lats has phone IDs as `labels` and word IDs as aux_labels, so we
    # need to invert it here.
    inverted_lats = k2.invert(lats)

    # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are phone IDs (a k2.RaggedInt with 2 axes)

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    del inverted_lats.aux_labels
    inverted_lats = k2.arc_sort(inverted_lats)

    am_path_lats = _intersect_device(inverted_lats,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=path_to_seq_map,
                                     sorted_match_a=True)

    # NOTE: `k2.connect` and `k2.top_sort` support only CPU at present
    am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu'))).to(device)

    # The `scores` of every arc consists of `am_scores` and `lm_scores`
    am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores

    am_scores = am_path_lats.get_tot_scores(use_double_scores=True,
                                            log_semiring=True)

    return am_scores
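The helper _intersect_device called above is not shown in this listing.
A minimal sketch, assuming it is a thin wrapper around k2.intersect_device
(Example #12 calls a variant that additionally batches the n-best list):

import k2
import torch

def _intersect_device(a_fsas: k2.Fsa, b_fsas: k2.Fsa,
                      b_to_a_map: torch.Tensor,
                      sorted_match_a: bool = False) -> k2.Fsa:
    # b_to_a_map[i] gives the index of the Fsa in a_fsas with which
    # b_fsas[i] is intersected.
    return k2.intersect_device(a_fsas, b_fsas, b_to_a_map=b_to_a_map,
                               sorted_match_a=sorted_match_a)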
Example #3
 def test_aux_as_tensor(self):
     s = '''
         0 1 1 1 0
         0 1 0 2 0
         0 3 2 3 0
         1 2 3 4 0
         1 3 4 5 0
         2 1 5 6 0
         2 5 -1 -1 0
         3 1 6 7 0
         4 5 -1 -1 0
         5
     '''
     fsa = k2.Fsa.from_str(s, num_aux_labels=1)
     assert fsa.device.type == 'cpu'
     dest = k2.invert(fsa)
     print(dest)
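The same check can run on GPU, assuming a CUDA-enabled k2 build. This is
a hypothetical extension, not part of the original test; it reuses the
string s defined above:

import torch

if torch.cuda.is_available():
    device = torch.device('cuda', 0)
    fsa_gpu = k2.Fsa.from_str(s, num_aux_labels=1).to(device)
    dest_gpu = k2.invert(fsa_gpu)
    assert dest_gpu.device.type == 'cuda'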
Example #4
    def test_composition_equivalence(self):
        index = _generate_fsa_vec()
        index = k2.arc_sort(k2.connect(k2.remove_epsilon(index)))

        src = _generate_fsa_vec()

        replace = k2.replace_fsa(src, index, 1)
        replace = k2.top_sort(replace)

        f_fsa = _construct_f(src)
        f_fsa = k2.arc_sort(f_fsa)
        intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
        intersect = k2.invert(intersect)
        intersect = k2.top_sort(intersect)
        delattr(intersect, 'aux_labels')

        assert k2.is_rand_equivalent(replace,
                                     intersect,
                                     log_semiring=True,
                                     delta=1e-3)
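_generate_fsa_vec is defined elsewhere in the test file, and _construct_f
appears below as Example #5. A hypothetical sketch of _generate_fsa_vec,
assuming k2.random_fsa and k2.create_fsa_vec are available:

import k2

def _generate_fsa_vec(num_fsas: int = 3) -> k2.Fsa:
    # Any small random FsaVec works for the rand-equivalence test above.
    fsas = [k2.random_fsa(max_symbol=10, min_num_arcs=5, max_num_arcs=15)
            for _ in range(num_fsas)]
    return k2.create_fsa_vec(fsas)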
Example #5
def _construct_f(fsa_vec: k2.Fsa) -> k2.Fsa:
    num_fsa = fsa_vec.shape[0]
    union = k2.union(fsa_vec)
    # k2.union creates a new start state whose first num_fsa arcs enter
    # the component FSAs; give arc i the aux_label i + 1 so that, after
    # the final invert, symbol i + 1 selects the i-th component.
    union.aux_labels = torch.zeros(union.num_arcs, dtype=torch.int32)
    union.aux_labels[0:num_fsa] = torch.tensor(list(range(1, 1 + num_fsa)),
                                               dtype=torch.int32)
    union_str = k2.to_str_simple(union)
    states_num = union.shape[0]

    # Rewrite the union in text form: arcs entering the final state are
    # redirected to state 0, -1 labels become epsilons, and a new final
    # arc is added from state 0, so the components become loopable.
    new_str_array = []
    new_str_array.append("0 {} -1 0 0".format(states_num - 1))
    for line in union_str.strip().split("\n"):
        tokens = line.strip().split(" ")
        if len(tokens) == 5:
            tokens[1] = '0' if int(tokens[1]) == states_num - 1 else tokens[1]
            tokens[2] = '0' if int(tokens[2]) == -1 else tokens[2]
        new_str_array.append(" ".join(tokens))
    new_str = "\n".join(new_str_array)

    new_fsa = k2.Fsa.from_str(new_str, num_aux_labels=1)
    new_fsa_invert = k2.invert(new_fsa)
    return new_fsa_invert
Example #6
 def test_aux_as_ragged(self):
     s = '''
         0 1 1 0
         0 1 0 0
         0 3 2 0
         1 2 3 0
         1 3 4 0
         2 1 5 0
         2 5 -1 0
         3 1 6 0
         4 5 -1 0
         5
     '''
     fsa = k2.Fsa.from_str(s)
     assert fsa.device.type == 'cpu'
     aux_row_splits = torch.tensor([0, 2, 3, 3, 6, 6, 7, 8, 10, 11],
                                   dtype=torch.int32)
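     # create_ragged_shape2(row_splits, row_ids, cached_tot_size) builds
     # a 2-axis ragged shape; passing None for row_ids lets k2 derive
     # them from row_splits, and 11 is the total number of elements.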
     aux_shape = k2.ragged.create_ragged_shape2(aux_row_splits, None, 11)
     aux_values = torch.tensor([1, 2, 3, 5, 6, 7, 8, -1, 9, 10, -1],
                               dtype=torch.int32)
     fsa.aux_labels = k2.RaggedInt(aux_shape, aux_values)
     dest = k2.invert(fsa)
     print(dest)  # will print aux_labels as well
Example #7
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, stored in a ragged tensor (k2.RaggedInt).
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted_rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.

    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
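The try/except above prunes the inverted lattice and retries the
intersection when the first attempt runs out of memory. A minimal sketch
of that pattern in isolation (a hypothetical helper, not from the
original code):

import k2
import torch

def intersect_with_oom_fallback(g: k2.Fsa, lats_no_loops: k2.Fsa,
                                b_to_a_map: torch.Tensor,
                                threshold: float = 0.001) -> k2.Fsa:
    try:
        b = k2.add_epsilon_self_loops(lats_no_loops)
        return k2.intersect_device(g, b, b_to_a_map, sorted_match_a=True)
    except RuntimeError:
        # Drop arcs whose posterior is below `threshold` to shrink the
        # lattice, then retry once.
        pruned = k2.prune_on_arc_post(lats_no_loops, threshold, True)
        b = k2.add_epsilon_self_loops(pruned)
        return k2.intersect_device(g, b, b_to_a_map, sorted_match_a=True)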
Example #8
def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                              ) -> Dict[str, k2.Fsa]:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict of FsaVec, whose keys are lm_scale values and whose values are
      the best decoding paths for each sequence in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, stored in a ragged tensor (k2.RaggedInt).
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ', inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = dict()
    #
    # The following implements
    # scores = (scores - lm_scores)/lm_scale + lm_scores
    #        = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    #
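    # Check: (scores - lm_scores)/lm_scale + lm_scores
    #      = scores/lm_scale - lm_scores/lm_scale + lm_scores
    #      = scores/lm_scale + lm_scores*(1 - 1/lm_scale)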
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
Example #9
import k2
s = '''
0 1 2 10 0.1
1 2 -1 -1 0.2
2
'''
fsa = k2.Fsa.from_str(s, num_aux_labels=1)
inverted_fsa = k2.invert(fsa)
fsa.draw('before_invert.svg', title='before invert')
inverted_fsa.draw('after_invert.svg', title='after invert')
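# In before_invert.svg the arc out of state 0 carries 2:10 (label 2,
# aux_label 10); in after_invert.svg it carries 10:2, i.e. the labels
# and aux_labels have been swapped.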
Example #10
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''
    (Ideas of this function are from Dan)

    It implements something like CTC prefix beam search using n-best lists

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total scores
    of these sequences in the log-semiring. The one with the max score
    is used as the decoding output.
    '''

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.

    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if inv_lats is already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels

    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    tot_scores = path_lats.get_tot_scores(use_double_scores=True,
                                          log_semiring=True)
    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `paths`, we use `new2old`
    # here to convert argmax_indexes to the indexes into `paths`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
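Since k2.ragged.unique_sequences reorders paths within each sequence,
new2old is what lets the argmax indexes be mapped back. A toy
illustration, assuming the same k2.ragged API used above:

import k2

seqs = k2.RaggedInt('[ [[1 2] [1 2] [3]] ]')  # axes [seq][path][word]
unique, _, new2old = k2.ragged.unique_sequences(
    seqs, need_num_repeats=False, need_new2old_indexes=True)
# unique keeps one copy of [1 2] plus [3] (order may differ); new2old
# maps each surviving path back to its index among the original three.
print(unique, new2old)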
Example #11
def whole_lattice_rescoring(lats: Fsa, G_with_epsilon_loops: Fsa) -> Fsa:
    '''Rescore the 1st pass lattice with an LM.

    In general, the G in HLG used to obtain `lats` is a 3-gram LM.
    This function replaces the 3-gram LM in `lats` with a 4-gram LM.

    Args:
      lats:
        The decoding lattice from the 1st pass. We assume it is the result
        of intersecting HLG with the network output.
      G_with_epsilon_loops:
        An LM. It is usually a 4-gram LM with epsilon self-loops.
        It should be arc sorted.
    Returns:
      Return a new lattice rescored with a given G.
    '''
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None), \
            f'{G_with_epsilon_loops.shape}'

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now lats contains only acoustic scores

    # We will use lm_scores from the given G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # inverted_lats has word IDs as labels.
    # Its aux_labels are token IDs, stored in a ragged tensor
    # (k2.RaggedInt) if lats.aux_labels is a ragged tensor.
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)

    while True:
        try:
            rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                                 inverted_lats,
                                                 b_to_a_map,
                                                 sorted_match_a=True)
            break
        except RuntimeError as e:
            logging.info(f'Caught exception:\n{e}\n')
            # Usually, this is an OOM exception. We reduce
            # the size of the lattice and redo k2.intersect_device()

            # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here
            # to avoid OOM. We may need to fine tune it.
            logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
            inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
            logging.info(f'num_arcs after: {inverted_lats.num_arcs}')

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))

    # inv_rescoring_lats has token IDs as labels
    # and word IDs as aux_labels.
    inv_rescoring_lats = k2.invert(rescoring_lats)
    return inv_rescoring_lats
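A hypothetical call site (variable names assumed, not from the original);
per the docstring, the G passed in must be arc sorted and carry epsilon
self-loops:

# G4 = k2.add_epsilon_self_loops(G4)
# G4 = k2.arc_sort(G4)
# rescored_lats = whole_lattice_rescoring(lats, G4)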
Example #12
def compute_am_scores_and_lm_scores(
    lats: k2.Fsa,
    word_fsas_with_epsilon_loops: k2.Fsa,
    path_to_seq_map: torch.Tensor,
    device: str = "cuda",
    batch_size: int = 500,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute AM and LM scores of n-best lists (represented as word_fsas).

    Args:
      lats:
        An FsaVec, which is the output of `k2.intersect_dense_pruned`.
        It must have the attribute `lm_scores`.
      word_fsas_with_epsilon_loops:
        An FsaVec representing an n-best list. Note that it has been processed
        by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates
        which sequence the i-th Fsa in word_fsas_with_epsilon_loops belongs to.
        path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
      device:
        The device on which the returned scores are placed.
      batch_size:
        Batchify the n-best list when intersecting with inverted_lats.
        You can tune this to avoid GPU OOM issues or to increase GPU
        utilization.
    Returns:
      Return a tuple of (1-D torch.Tensor, 1-D torch.Tensor) containing
      the AM and LM scores of each path.
      `am_scores.numel() == word_fsas_with_epsilon_loops.shape[0]`
      `lm_scores.numel() == word_fsas_with_epsilon_loops.shape[0]`
    """
    assert len(lats.shape) == 3

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lats`, we use k2.intersect_device here.
    #
    # lats has phone IDs as `labels` and word IDs as aux_labels, so we
    # need to invert it here.
    inverted_lats = k2.invert(lats)

    # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are phone IDs (a k2.RaggedInt with 2 axes)

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    del inverted_lats.aux_labels
    inverted_lats = k2.arc_sort(inverted_lats)

    am_path_lats = _intersect_device(
        inverted_lats,
        word_fsas_with_epsilon_loops,
        b_to_a_map=path_to_seq_map,
        sorted_match_a=True,
        batch_size=batch_size,
    )

    am_path_lats = k2.top_sort(k2.connect(am_path_lats))

    # The `scores` of every arc consists of `am_scores` and `lm_scores`
    tot_score_device = "cpu"
    if hasattr(lats, "lm_scores"):
        am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores
        am_scores = (
            am_path_lats.to(tot_score_device)
            .get_tot_scores(use_double_scores=True, log_semiring=False)
            .to(device)
        )

        # Start to compute lm_scores
        am_path_lats.scores = am_path_lats.lm_scores
        lm_scores = (
            am_path_lats.to(tot_score_device)
            .get_tot_scores(use_double_scores=True, log_semiring=False)
            .to(device)
        )
    else:
        am_scores = (
            am_path_lats.to(tot_score_device)
            .get_tot_scores(use_double_scores=True, log_semiring=False)
            .to(device)
        )
        lm_scores = None

    return am_scores, lm_scores
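The batched _intersect_device that this example relies on is likewise
not shown. A hypothetical sketch, splitting the n-best list into chunks
of batch_size to bound peak GPU memory; k2.index_fsa and k2.cat are
assumed to be available in the k2 version used:

import k2
import torch

def _intersect_device(a_fsas: k2.Fsa, b_fsas: k2.Fsa,
                      b_to_a_map: torch.Tensor, sorted_match_a: bool,
                      batch_size: int = 500) -> k2.Fsa:
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(a_fsas, b_fsas, b_to_a_map=b_to_a_map,
                                   sorted_match_a=sorted_match_a)
    ans = []
    for start in range(0, num_fsas, batch_size):
        end = min(start + batch_size, num_fsas)
        indexes = torch.arange(start, end, dtype=torch.int32,
                               device=b_fsas.device)
        # Intersect one chunk of the n-best list at a time.
        fsas = k2.index_fsa(b_fsas, indexes)
        ans.append(k2.intersect_device(a_fsas, fsas,
                                       b_to_a_map=b_to_a_map[start:end],
                                       sorted_match_a=sorted_match_a))
    return k2.cat(ans)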