Exemple #1
def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 500,
) -> k2.Fsa:
    """Wrap k2.intersect_device

    This is a wrapper of k2.intersect_device and its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    CUDA OOM error.

    The arguments and return value of this function are the same as
    :func:`k2.intersect_device`, except for the extra `batch_size`.

    Args:
      a_fsas:
        Forwarded unchanged to every call of `k2.intersect_device`.
      b_fsas:
        The FsaVec that is split along its first axis into batches.
      b_to_a_map:
        A 1-D tensor with one entry per FSA in `b_fsas`; it is sliced
        with the same indexes as `b_fsas` for every batch.
      sorted_match_a:
        Forwarded unchanged to `k2.intersect_device`.
      batch_size:
        Maximum number of FSAs from `b_fsas` processed per call.

    Returns:
      The concatenation (`k2.cat`) of the per-batch results, or the direct
      result of `k2.intersect_device` if `b_fsas` fits in a single batch.

    Raises:
      ValueError: If `b_fsas` has to be split but `batch_size` is not
        positive.

    NOTE: You can decrease batch_size in case of CUDA out of memory error.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        # Everything fits into one batch; no need to split.
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, given: {batch_size}')

    ans = []
    for start in range(0, num_fsas, batch_size):
        end = min(start + batch_size, num_fsas)
        # `.to(b_to_a_map)` puts the indexes on the same device and gives
        # them the same dtype as `b_to_a_map`.
        indexes = torch.arange(start, end).to(b_to_a_map)

        fsas = k2.index_fsa(b_fsas, indexes)
        b_to_a = k2.index_select(b_to_a_map, indexes)
        path_lats = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lats)

    return k2.cat(ans)
Exemple #2
    def intersect(self, lats: Fsa) -> 'Nbest':
        '''Intersect this Nbest object with a lattice and get 1-best
        path from the resulting FsaVec.

        Caution:
          We assume FSAs in `self.fsa` don't have epsilon self-loops.
          We also assume `self.fsa.labels` and `lats.labels` are token IDs.

        Args:
          lats:
            An FsaVec. It must be on the same device as `self.fsa`, have
            3 axes, and contain as many FSAs as `self.shape` has entries
            on axis 0.
        Returns:
          Return a new Nbest. This new Nbest shares the same shape with `self`,
          while its `fsa` is the 1-best path from intersecting `self.fsa` and
          `lats`.
        '''
        assert self.fsa.device == lats.device, \
                f'{self.fsa.device} vs {lats.device}'
        assert len(lats.shape) == 3, f'{lats.shape}'
        assert lats.arcs.dim0() == self.shape.dim0(), \
                f'{lats.arcs.dim0()} vs {self.shape.dim0()}'

        lats = k2.arc_sort(lats)  # no-op if lats is already arc sorted

        fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa)

        # path_to_seq_map[i] is the index of the lattice in `lats` that the
        # i-th path of `self.fsa` is intersected with.
        path_to_seq_map = self.shape.row_ids(1)

        # `lats` was arc sorted above, hence sorted_match_a=True.
        ans_lats = k2.intersect_device(a_fsas=lats,
                                       b_fsas=fsas_with_epsilon_loops,
                                       b_to_a_map=path_to_seq_map,
                                       sorted_match_a=True)

        one_best = k2.shortest_path(ans_lats, use_double_scores=True)

        one_best = k2.remove_epsilon(one_best)

        return Nbest(fsa=one_best, shape=self.shape)
    def test(self):
        '''Check k2.intersect_device output and gradients.

        Runs on CPU and, if available, on CUDA, for every combination of
        identity / all-to-zero `b_to_a_map` and `sorted_match_a`.
        '''
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            for use_identity_map, sorted_match_a in [(True, True),
                                                     (False, True),
                                                     (True, False),
                                                     (False, False)]:
                # recognizes (0|1)(0|2)
                s1 = '''
                    0 1 0 0.1
                    0 1 1 0.2
                    1 2 0 0.4
                    1 2 2 0.3
                    2 3 -1 0.5
                    3
                '''

                # recognizes 02*
                s2 = '''
                    0 1 0 1
                    1 1 2 2
                    1 2 -1 3
                    2
                '''

                # recognizes 1*0
                s3 = '''
                    0 0 1 10
                    0 1 0 20
                    1 2 -1 30
                    2
                '''
                a_fsa = k2.Fsa.from_str(s1).to(device)
                b_fsa_1 = k2.Fsa.from_str(s2).to(device)
                b_fsa_2 = k2.Fsa.from_str(s3).to(device)

                # Gradients w.r.t. the arc scores are checked at the end.
                a_fsa.requires_grad_(True)
                b_fsa_1.requires_grad_(True)
                b_fsa_2.requires_grad_(True)

                b_fsas = k2.create_fsa_vec([b_fsa_1, b_fsa_2])
                if use_identity_map:
                    a_fsas = k2.create_fsa_vec([a_fsa, a_fsa])
                    b_to_a_map = torch.tensor([0, 1],
                                              dtype=torch.int32).to(device)
                else:
                    a_fsas = k2.create_fsa_vec([a_fsa])
                    b_to_a_map = torch.tensor([0, 0],
                                              dtype=torch.int32).to(device)

                c_fsas = k2.intersect_device(a_fsas, b_fsas, b_to_a_map,
                                             sorted_match_a)
                assert c_fsas.shape == (2, None, None)
                c_fsas = k2.connect(c_fsas.to('cpu'))
                # c_fsas[0] recognizes: 02
                # c_fsas[1] recognizes: 10

                actual_str_0 = k2.to_str(c_fsas[0])
                expected_str_0 = '\n'.join(
                    ['0 1 0 1.1', '1 2 2 2.3', '2 3 -1 3.5', '3'])
                assert actual_str_0.strip() == expected_str_0

                actual_str_1 = k2.to_str(c_fsas[1])
                expected_str_1 = '\n'.join(
                    ['0 1 1 10.2', '1 2 0 20.4', '2 3 -1 30.5', '3'])
                assert actual_str_1.strip() == expected_str_1

                loss = c_fsas.scores.sum()
                (-loss).backward()
                # Every arc of a_fsa is used by exactly one surviving path,
                # except the final arc, which is used by both (hence -2).
                assert torch.allclose(
                    a_fsa.grad,
                    torch.tensor([-1, -1, -1, -1, -2]).to(a_fsa.grad))
                assert torch.allclose(
                    b_fsa_1.grad,
                    torch.tensor([-1, -1, -1]).to(b_fsa_1.grad))
                assert torch.allclose(
                    b_fsa_2.grad,
                    torch.tensor([-1, -1, -1]).to(b_fsa_2.grad))
Exemple #4
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use whole lattice to rescore.

    Caution:
      `lats.scores` is modified in place: `lats.lm_scores` is subtracted
      from it.

    Args:
      lats:
        An FsaVec It can be the output of `k2.intersect_dense_pruned`.
        It must have an attribute `lm_scores`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
    Returns:
      An FsaVec containing the best path of each rescored lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(inverted_lats)

    # Every lattice is intersected with the single FSA in G.
    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        # Presumably an out-of-memory error; prune the lattices and retry
        # once. A second failure propagates to the caller.
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu')).to(device))
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.

    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
Exemple #5
def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                              ) -> Dict[str, k2.Fsa]:
    '''Use whole lattice to rescore.

    Caution:
      `lats` is modified in place: `lats.lm_scores` is subtracted from
      `lats.scores` and the attribute `lm_scores` is then deleted.

    Args:
      lats:
        An FsaVec It can be the output of `k2.intersect_dense_pruned`.
        It must have an attribute `lm_scores`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict of FsaVec, whose key is a lm_scale and the value represents the
      best decoding path for each sequence in the lattice. The keys have
      the form `lm_scale_{value}`.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    # Every lattice is intersected with the single FSA in G.
    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        # Presumably an out-of-memory error; prune the lattices and retry
        # once. A second failure propagates to the caller.
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ', inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats.to('cpu')).to(device))

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = {}
    # The following implements
    # scores = (scores - lm_scores)/lm_scale + lm_scores
    #        = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    #
    # The unscaled scores are saved once because inv_lats.scores is
    # overwritten on every iteration.
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''Decode by rescoring n-best word sequences sampled from a lattice.

    (Ideas of this function are from Dan)

    It implements something like CTC prefix beam search using n-best lists

    The basic idea is to first extract n-best paths from the given lattice,
    build a word seqs from these paths, and compute the total scores
    of these sequences in the log-semiring. The one with the max score
    is used as the decoding output.

    Args:
      lats:
        An FsaVec with phone IDs as labels and word IDs as aux_labels.
      num_paths:
        Number of paths to sample from each lattice.
    Returns:
      An FsaVec of linear FSAs, one best path per sequence. Its labels are
      phone IDs and its `aux_labels` are taken from `lats.aux_labels`.
    '''
    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.

    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if inv_lats is already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels

    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    tot_scores = path_lats.get_tot_scores(True, True)
    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `paths`, we use `new2old`
    # here to convert argmax_indexes to the indexes into `paths`.
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
Exemple #7
def levenshtein_alignment(
        refs: Fsa,
        hyps: Fsa,
        hyp_to_ref_map: torch.Tensor,
        sorted_match_ref: bool = False,
) -> Fsa:
    '''Get the levenshtein alignment of two FsaVecs

    This function supports both CPU and GPU. But it is very slow on CPU.

    Caution:
      `hyps` is modified in place: its attribute `aux_labels` is renamed
      to `hyp_labels`.

    Args:
      refs:
        An FsaVec (must have 3 axes, i.e., `len(refs.shape) == 3`). It is the
        output Fsa of the :func:`levenshtein_graph`.
      hyps:
        An FsaVec (must have 3 axes) on the same device as `refs`. It is the
        output Fsa of the :func:`levenshtein_graph`.
      hyp_to_ref_map:
        A 1-D torch.Tensor with dtype torch.int32 on the same device
        as `refs`. Map from FSA-id in `hyps` to the corresponding
        FSA-id in `refs` that we want to get levenshtein alignment with.
        E.g. might be an identity map, or all-to-zero, or something the
        user chooses.

        Requires
            - `hyp_to_ref_map.shape[0] == hyps.shape[0]`
            - `0 <= hyp_to_ref_map[i] < refs.shape[0]`
      sorted_match_ref:
        If true, the arcs of refs must be sorted by label (checked by
        calling code via properties), and we'll use a matching approach
        that requires this.

    Returns:
      Returns an FsaVec containing the alignment information and satisfying
      `ans.Dim0() == hyps.Dim0()`. Two attributes named `ref_labels` and
      `hyp_labels` will be added to the returned FsaVec. `ref_labels` contains
      the aligned sequences of refs and `hyp_labels` contains the aligned
      sequences of hyps. You can get the levenshtein distance by calling
      `get_tot_scores` on the returned FsaVec.

    Examples:
      >>> hyps = k2.levenshtein_graph([[1, 2, 3], [1, 3, 3, 2]])
      >>> refs = k2.levenshtein_graph([[1, 2, 4]])
      >>> alignment = k2.levenshtein_alignment(
      ...     refs, hyps,
      ...     hyp_to_ref_map=torch.tensor([0, 0], dtype=torch.int32),
      ...     sorted_match_ref=True)
      >>> alignment.labels
      tensor([ 1,  2,  0, -1,  1,  0,  0,  0, -1], dtype=torch.int32)
      >>> alignment.ref_labels
      tensor([ 1,  2,  4, -1,  1,  2,  4,  0, -1], dtype=torch.int32)
      >>> alignment.hyp_labels
      tensor([ 1,  2,  3, -1,  1,  3,  3,  2, -1], dtype=torch.int32)
      >>> -alignment.get_tot_scores(
      ...     use_double_scores=False, log_semiring=False)
      tensor([1., 3.])
    '''
    assert hasattr(refs, "aux_labels")
    assert hasattr(hyps, "aux_labels")

    # Rename so that the attribute coming from `hyps` does not share the
    # name `aux_labels` with the one coming from `refs`.
    hyps.rename_tensor_attribute_("aux_labels", "hyp_labels")

    lattice = k2.intersect_device(
        refs, hyps, b_to_a_map=hyp_to_ref_map, sorted_match_a=sorted_match_ref)
    lattice = k2.remove_epsilon_self_loops(lattice)

    # After invert_(), the former aux_labels are the labels and vice versa;
    # the two renames below then expose them as `ref_labels` and `labels`.
    alignment = k2.shortest_path(lattice, use_double_scores=True).invert_()
    alignment.rename_tensor_attribute_("labels", "ref_labels")
    alignment.rename_tensor_attribute_("aux_labels", "labels")

    # Subtract the per-arc offset stored in this internal attribute.
    alignment.scores -= getattr(
        alignment, "__ins_del_score_offset_internal_attr_")

    return alignment
Exemple #8
def whole_lattice_rescoring(lats: Fsa, G_with_epsilon_loops: Fsa) -> Fsa:
    '''Rescore the 1st pass lattice with an LM.

    In general, the G in HLG used to obtain `lats` is a 3-gram LM.
    This function replaces the 3-gram LM in `lats` with a 4-gram LM.

    Caution:
      `lats` is modified in place: `lats.lm_scores` is subtracted from
      `lats.scores` and the attribute `lm_scores` is then deleted.

    Args:
      lats:
        The decoding lattice from the 1st pass. We assume it is the result
        of intersecting HLG with the network output.
      G_with_epsilon_loops:
        An LM. It is usually a 4-gram LM with epsilon self-loops.
        It should be arc sorted.
    Returns:
      Return a new lattice rescored with a given G.
    '''
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None), \
            f'{G_with_epsilon_loops.shape}'

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now lats contains only acoustic scores

    # We will use lm_scores from the given G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # inverted_lats has word IDs as labels.
    # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
    # if lats.aux_labels is a ragged tensor
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    # Every lattice is intersected with the single FSA in G.
    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)

    # Retry with a progressively pruned lattice until the intersection
    # succeeds.
    while True:
        try:
            rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                                 inverted_lats,
                                                 b_to_a_map,
                                                 sorted_match_a=True)
            break
        except RuntimeError as e:
            logging.info(f'Caught exception:\n{e}\n')
            # Usually, this is an OOM exception. We reduce
            # the size of the lattice and redo k2.intersect_device()

            # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here
            # to avoid OOM. We may need to fine tune it.
            logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
            inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
            logging.info(f'num_arcs after: {inverted_lats.num_arcs}')

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))

    # inv_rescoring_lats has token IDs as labels
    # and word IDs as aux_labels.
    inv_rescoring_lats = k2.invert(rescoring_lats)
    return inv_rescoring_lats