Example #1
import k2
from k2 import Fsa

# Note: `Nbest` is defined in the surrounding codebase; judging from the
# constructor call below, it is a thin wrapper holding an FsaVec and the
# RaggedShape that maps sequences to paths.


def generate_nbest_list(lats: Fsa, num_paths: int) -> Nbest:
    '''Generate an n-best list from a lattice.

    Args:
      lats:
        The decoding lattice from the first pass after LM rescoring.
        lats is an FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`
      num_paths:
        Size of n for n-best list. CAUTION: After removing paths
        that represent the same token sequences, the number of paths
        in different sequences may not be equal.
    Returns:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.
    '''
    assert len(lats.shape) == 3

    # CAUTION: We use `phones` instead of `tokens` here because
    # :func:`compile_HLG` uses `phones`
    #
    # Note: compile_HLG is from k2-fsa/snowfall
    assert hasattr(lats, 'phones')

    assert not hasattr(lats, 'tokens')
    lats.tokens = lats.phones
    # we use tokens instead of phones in the following code

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedTensor with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # token_seqs is a k2.RaggedTensor sharing the same shape as `paths`
    # but it contains token IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][token_id]
    token_seqs = k2.ragged.index(lats.tokens, paths)

    # Remove epsilons (0s) and -1 from token_seqs
    token_seqs = token_seqs.remove_values_leq(0)

    # unique_token_seqs is still a k2.RaggedTensor with axes
    # [seq][path][token_id].
    # But the number of paths in each sequence may now differ.
    unique_token_seqs, _, _ = token_seqs.unique(need_num_repeats=False,
                                                need_new2old_indexes=False)

    seq_to_path_shape = unique_token_seqs.shape.get_layer(0)

    # Remove the seq axis.
    # Now unique_token_seqs has only two axes [path][token_id]
    unique_token_seqs = unique_token_seqs.remove_axis(0)

    token_fsas = k2.linear_fsa(unique_token_seqs)

    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
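
A minimal usage sketch (added for illustration; `rescored_lats`, the demo
function, and num_paths=100 are hypothetical, not from the original source):

def demo_generate_nbest_list(rescored_lats: k2.Fsa) -> None:
    # rescored_lats is assumed to be the FsaVec returned by
    # whole_lattice_rescoring(), carrying a `phones` attribute.
    nbest = generate_nbest_list(rescored_lats, num_paths=100)
    # nbest.fsa has one linear FSA per unique path; nbest.shape maps each
    # sequence to its paths, so row_splits(1) shows how many unique paths
    # each sequence ended up with.
    print(nbest.shape.row_splits(1))
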
Example #2
import k2
import torch

# `compute_am_scores_and_lm_scores` and `remove_repeated_and_leq` are
# helper functions defined elsewhere in the same codebase.


def nbest_am_lm_scores(
    lats: k2.Fsa,
    num_paths: int,
    device: str = "cuda",
    batch_size: int = 500,
):
    """Compute am scores with word_seqs

    Compatible with both ctc_decoding or TLG decoding.
    """
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
    if isinstance(lats.aux_labels, torch.Tensor):
        word_seqs = k2.ragged.index(lats.aux_labels.contiguous(), paths)
    else:
        # '_k2.RaggedInt' object has no attribute 'contiguous'
        word_seqs = lats.aux_labels.index(paths)
        word_seqs = word_seqs.remove_axis(word_seqs.num_axes - 2)

    # With ctc_decoding, word_seqs stores token_ids.
    # With TLG decoding, word_seqs stores word_ids.
    word_seqs = word_seqs.remove_values_leq(0)
    unique_word_seqs, num_repeats, new2old = word_seqs.unique(
        need_num_repeats=True, need_new2old_indexes=True
    )

    seq_to_path_shape = unique_word_seqs.shape.get_layer(0)
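    # path_to_seq_map[i] is the seq to which the i-th path belongs.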
    path_to_seq_map = seq_to_path_shape.row_ids(1)
    # used to split final computed tot_scores
    seq_to_path_splits = seq_to_path_shape.row_splits(1)

    unique_word_seqs = unique_word_seqs.remove_axis(0)
    word_fsas = k2.linear_fsa(unique_word_seqs)

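    # k2 treats epsilon (0) as a normal symbol during intersection, so we
    # add epsilon self-loops to word_fsas before intersecting them with
    # lattices that contain epsilon arcs.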
    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores, lm_scores = compute_am_scores_and_lm_scores(
        lats, word_fsas_with_epsilon_loops, path_to_seq_map, device, batch_size
    )

    token_seqs = k2.ragged.index(lats.labels.contiguous(), paths)
    token_seqs = token_seqs.remove_axis(0)

    token_ids, _ = token_seqs.index(new2old, axis=0)
    token_ids = token_ids.tolist()
    # Now remove repeated tokens as well as 0s and -1s.
    token_ids = [remove_repeated_and_leq(tokens) for tokens in token_ids]
    return (am_scores, lm_scores, token_ids, new2old, path_to_seq_map,
            seq_to_path_splits)
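
A hedged usage sketch (added for illustration; the lattice source, the demo
function, and the LM scale of 0.5 are assumptions, not from the original
source):

def demo_nbest_am_lm_scores(lats: k2.Fsa) -> None:
    # lats is assumed to be a decoding lattice, e.g. the output of
    # k2.intersect_dense_pruned with CTC or TLG decoding.
    (am_scores, lm_scores, token_ids, new2old, path_to_seq_map,
     seq_to_path_splits) = nbest_am_lm_scores(lats, num_paths=100)
    # Combine the two scores; the 0.5 LM scale is illustrative only.
    tot_scores = am_scores + 0.5 * lm_scores
    print(tot_scores, len(token_ids))
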
Example #3
import torch

import k2

# `compute_am_scores` and `_intersect_device` are helper functions defined
# elsewhere in the same codebase.


def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
                             num_paths: int) -> k2.Fsa:
    '''Decode using n-best list with LM rescoring.

    `lats` is a decoding lattice, which has 3 axes. This function first
    extracts `num_paths` paths from `lats` for each sequence using
    `k2.random_paths`. The `am_scores` of these paths are computed.
    For each path, its `lm_scores` is computed using `G` (which is an LM).
    The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
    The path with the greatest `tot_scores` within a sequence is used
    as the decoding output.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      num_paths:
        It is the size `n` in the `n-best` list.
    Returns:
      An FsaVec representing the best decoding path for each sequence
      in the lattice.
    '''
    device = lats.device

    assert len(lats.shape) == 3
    assert hasattr(lats, 'aux_labels')
    assert hasattr(lats, 'lm_scores')

    assert G.shape == (1, None, None)
    assert G.device == device
    assert not hasattr(G, 'aux_labels')

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # unique_word_seqs is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a seq.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=True, need_new2old_indexes=True)

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops,
                                  path_to_seq_map)

    # Now compute lm_scores
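    # G contains only a single Fsa, so every path is matched against
    # index 0 of G.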
    b_to_a_map = torch.zeros_like(path_to_seq_map)
    lm_path_lats = _intersect_device(G,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=b_to_a_map,
                                     sorted_match_a=True)
    lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
    lm_scores = lm_path_lats.get_tot_scores(use_double_scores=True,
                                            log_semiring=True)

    tot_scores = am_scores + lm_scores

    # Remember that we used `k2.ragged.unique_sequences` to remove repeated
    # paths to avoid redundant computation in `k2.intersect_device`.
    # Now we use `num_repeats` to correct the scores for each path.
    #
    # NOTE(fangjun): It is commented out as it leads to a worse WER
    # tot_scores = tot_scores * num_repeats.values()

    # TODO(fangjun): We may need to add `k2.RaggedDouble`
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))
    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
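
A minimal calling sketch (added for illustration; the demo function, the
requirements on the inputs, and num_paths=100 are assumptions based on the
docstring and asserts above, not code from the original source):

def demo_rescore(lats: k2.Fsa, G: k2.Fsa) -> None:
    # lats: first-pass decoding lattice with `lm_scores` attached.
    # G: an arc-sorted FsaVec holding a single LM Fsa on the same device.
    best_paths = rescore_with_n_best_list(lats, G, num_paths=100)
    # best_paths.aux_labels holds the word IDs of the winning path
    # for each sequence.
    print(best_paths.shape)

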
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''
    (Ideas of this function are from Dan)

    It implements something like CTC prefix beam search using n-best lists.

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total scores
    of these sequences in the log semiring. The one with the max score
    is used as the decoding output.
    '''

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.

    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if inv_lats is already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels

    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    tot_scores = path_lats.get_tot_scores(use_double_scores=True,
                                          log_semiring=True)
    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `paths`, we use `new2old`
    # here to convert argmax_indexes to the indexes into `paths`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
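
A hedged example (illustrative only; the demo function and num_paths=30 are
arbitrary choices, not from the original source):

def demo_nbest_decoding(lats: k2.Fsa) -> None:
    # lats has phone IDs as labels and word IDs as aux_labels, e.g. the
    # output of k2.intersect_dense_pruned with an HLG decoding graph.
    best_paths = nbest_decoding(lats, num_paths=30)
    # best_paths is an FsaVec; its aux_labels are the word IDs of the
    # best path for each sequence.
    print(best_paths.shape)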