Example #1
# Excerpt from a k2 test case; self.devices holds the devices
# (CPU and, if available, CUDA) to run the test on.
def test(self):
    for device in self.devices:
        s = '''
            0 1 1 0.2
            0 1 2 0.1
            1 2 2 0.1
            1 2 3 0.2
            2 3 -1 0
            3
        '''
        fsa = k2.Fsa.from_str(s).to(device)
        fsa_vec = k2.create_fsa_vec([fsa])
        threshold_prob = 0.5
        # Remove every arc whose posterior probability is below 0.5
        ans = k2.prune_on_arc_post(fsa_vec,
                                   threshold_prob,
                                   use_double_scores=True)
        expected = k2.Fsa.from_str('''
            0 1 1 0.2
            1 2 3 0.2
            2 3 -1 0
            3
        ''')
        assert str(ans[0].to('cpu')) == str(expected)
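The arcs scored 0.1 are removed because their posterior (about 0.475) falls below threshold_prob = 0.5, while the 0.2-scored arcs sit at about 0.525 and survive. To inspect those posteriors directly, one can use k2's Fsa.get_arc_post, the quantity prune_on_arc_post thresholds on; a minimal sketch:

import k2

s = '''
    0 1 1 0.2
    0 1 2 0.1
    1 2 2 0.1
    1 2 3 0.2
    2 3 -1 0
    3
'''
fsa_vec = k2.create_fsa_vec([k2.Fsa.from_str(s)])
# Log-posterior of each arc; exp() converts to probabilities.
# The 0.1-scored arcs come out near 0.475 < 0.5, hence the pruning above.
arc_post = fsa_vec.get_arc_post(log_semiring=True, use_double_scores=True)
print(arc_post.exp())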
Example #2
from typing import Dict, List

import k2
import torch


def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                              ) -> Dict[str, k2.Fsa]:
    '''Use the whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict whose keys are strings like 'lm_scale_0.8' and whose values
      are FsaVecs containing the best decoding path for each sequence in
      the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, stored as a ragged tensor (k2.RaggedInt).
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    # Intersect every inverted lattice (the "b" FsaVec) with the single
    # G in G_with_epsilon_loops (the "a" FsaVec), hence the all-zero map.
    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ', inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(
        k2.connect(rescoring_lats.to('cpu')).to(device))

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = dict()
    #
    # The following implements
    # scores = (scores - lm_scores)/lm_scale + lm_scores
    #        = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    #
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
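A hypothetical call site for the function above; `lats` and `G` are placeholders for an FsaVec from `k2.intersect_dense_pruned` (carrying an `lm_scores` attribute) and an arc-sorted LM FsaVec with epsilon self-loops:

# Illustrative only: `lats` and `G` must be built elsewhere and satisfy
# the assertions at the top of rescore_with_whole_lattice.
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2]
ans = rescore_with_whole_lattice(lats, G, lm_scale_list)
for key, best_paths in ans.items():
    # key looks like 'lm_scale_0.8'; best_paths is an FsaVec holding one
    # linear FSA (the best path) per sequence.
    print(key, best_paths.shape)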
Example #3
import k2
import torch


def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use the whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
    Returns:
      An FsaVec containing the best decoding path for each sequence
      in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, stored as a ragged tensor (k2.RaggedInt).
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.

    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
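Compared with Example #2, the distinguishing step here is the epsilon self-loop round trip around the intersection. A toy sketch of that round trip (the arc counts in the comments apply to this toy acceptor only):

import k2

a = k2.Fsa.from_str('''
    0 1 1 0.1
    1 2 -1 0.2
    2
''')
# add_epsilon_self_loops attaches an epsilon (label 0) self-loop to each
# non-final state, so this toy FSA grows from 2 arcs to 4.
a_loops = k2.add_epsilon_self_loops(a)
print(a.num_arcs, a_loops.num_arcs)  # 2 4
# remove_epsilon_self_loops strips them again.
print(k2.remove_epsilon_self_loops(a_loops).num_arcs)  # 2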
Example #4
import logging

import k2
import torch
from k2 import Fsa


def whole_lattice_rescoring(lats: Fsa, G_with_epsilon_loops: Fsa) -> Fsa:
    '''Rescore the 1st pass lattice with an LM.

    In general, the G in HLG used to obtain `lats` is a 3-gram LM.
    This function replaces the 3-gram LM in `lats` with a 4-gram LM.

    Args:
      lats:
        The decoding lattice from the 1st pass. We assume it is the result
        of intersecting HLG with the network output.
      G_with_epsilon_loops:
        An LM. It is usually a 4-gram LM with epsilon self-loops.
        It should be arc sorted.
    Returns:
      Return a new lattice rescored with a given G.
    '''
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None), \
            f'{G_with_epsilon_loops.shape}'

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now lats contains only acoustic scores

    # We will use lm_scores from the given G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # inverted_lats has word IDs as labels.
    # Its aux_labels are token IDs, stored as a ragged tensor
    # (k2.RaggedInt) if lats.aux_labels is itself a ragged tensor.
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)

    while True:
        try:
            rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                                 inverted_lats,
                                                 b_to_a_map,
                                                 sorted_match_a=True)
            break
        except RuntimeError as e:
            logging.info(f'Caught exception:\n{e}\n')
            # Usually, this is an OOM exception. We reduce
            # the size of the lattice and redo k2.intersect_device()

            # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here
            # to avoid OOM. We may need to fine-tune it.
            logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
            inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
            logging.info(f'num_arcs after: {inverted_lats.num_arcs}')

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))

    # inv_rescoring_lats has token IDs as labels
    # and word IDs as aux_labels.
    inv_rescoring_lats = k2.invert(rescoring_lats)
    return inv_rescoring_lats
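Unlike Examples #2 and #3, this version returns the rescored lattice itself and leaves LM scaling and best-path extraction to the caller. A sketch of that follow-up, assuming (as in Example #2) that the returned lattice carries an lm_scores attribute contributed by G; `lats` and `G_with_epsilon_loops` are placeholders:

# Hypothetical follow-up to whole_lattice_rescoring().
rescored = whole_lattice_rescoring(lats, G_with_epsilon_loops)
lm_scale = 0.8  # illustrative value
am_scores = rescored.scores - rescored.lm_scores
rescored.scores = am_scores / lm_scale + rescored.lm_scores
best_paths = k2.shortest_path(rescored, use_double_scores=True)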