Example #1
    def test(self):
        s0 = '''
            0 1 1 0.1
            0 2 2 0.2
            1 2 3 0.3
            2 3 -1 0.4
            3
        '''
        s1 = '''
            0 1 -1 0.5
            1
        '''
        s2 = '''
            0 2 1 0.6
            0 1 2 0.7
            1 3 -1 0.8
            2 1 3 0.9
            3
        '''
        fsa0 = k2.Fsa.from_str(s0).requires_grad_(True)
        fsa1 = k2.Fsa.from_str(s1).requires_grad_(True)
        fsa2 = k2.Fsa.from_str(s2).requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2])

        new_fsa21 = k2.index(fsa_vec, torch.tensor([2, 1], dtype=torch.int32))
        assert new_fsa21.shape == (2, None, None)
        assert torch.allclose(
            new_fsa21.arcs.values()[:, :3],
            torch.tensor([
                # fsa 2
                [0, 2, 1],
                [0, 1, 2],
                [1, 3, -1],
                [2, 1, 3],
                # fsa 1
                [0, 1, -1]
            ]).to(torch.int32))

        scale = torch.arange(new_fsa21.scores.numel())
        (new_fsa21.scores * scale).sum().backward()
        assert torch.allclose(fsa0.scores.grad, torch.tensor([0., 0, 0, 0]))
        assert torch.allclose(fsa1.scores.grad, torch.tensor([4.]))
        assert torch.allclose(fsa2.scores.grad, torch.tensor([0., 1., 2., 3.]))

        # now select only a single FSA
        fsa0.scores.grad = None
        fsa1.scores.grad = None
        fsa2.scores.grad = None

        new_fsa0 = k2.index(fsa_vec, torch.tensor([0], dtype=torch.int32))
        assert new_fsa0.shape == (1, None, None)

        scale = torch.arange(new_fsa0.scores.numel())
        (new_fsa0.scores * scale).sum().backward()
        assert torch.allclose(fsa0.scores.grad, torch.tensor([0., 1., 2., 3.]))
        assert torch.allclose(fsa1.scores.grad, torch.tensor([0.]))
        assert torch.allclose(fsa2.scores.grad, torch.tensor([0., 0., 0., 0.]))
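The gradient routing checked above is the same as what plain `torch.index_select` does: each selected element passes its upstream gradient back to the position it was gathered from, and positions that were never selected receive zero. A minimal pure-PyTorch sketch of the same effect (no k2 involved):

import torch

src = torch.tensor([1., 2., 3., 4.], requires_grad=True)
out = torch.index_select(src, 0, torch.tensor([2, 0]))  # gather elements 2 and 0
(out * torch.tensor([10., 20.])).sum().backward()
# each scale factor lands at the position it was gathered from
assert torch.allclose(src.grad, torch.tensor([20., 0., 10., 0.]))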
Example #2
    def test(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            src_row_splits = torch.tensor([0, 2, 3, 3, 6],
                                          dtype=torch.int32,
                                          device=device)
            src_shape = k2.ragged.create_ragged_shape2(src_row_splits, None, 6)
            src_values = torch.tensor([1, 2, 3, 4, 5, 6],
                                      dtype=torch.int32,
                                      device=device)
            src = k2.RaggedInt(src_shape, src_values)

            # index with ragged int
            index_row_splits = torch.tensor([0, 2, 2, 3, 7],
                                            dtype=torch.int32,
                                            device=device)
            index_shape = k2.ragged.create_ragged_shape2(
                index_row_splits, None, 7)
            index_values = torch.tensor([0, 3, 2, 1, 2, 1, 0],
                                        dtype=torch.int32,
                                        device=device)
            ragged_index = k2.RaggedInt(index_shape, index_values)
            ans = k2.index(src, ragged_index)
            expected_row_splits = torch.tensor([0, 5, 5, 5, 9],
                                               dtype=torch.int32,
                                               device=device)
            self.assertTrue(
                torch.allclose(ans.row_splits(1), expected_row_splits))
            expected_values = torch.tensor([1, 2, 4, 5, 6, 3, 3, 1, 2],
                                           dtype=torch.int32,
                                           device=device)
            self.assertTrue(torch.allclose(ans.values(), expected_values))

            # index with tensor
            tensor_index = torch.tensor([0, 3, 2, 1, 2, 1],
                                        dtype=torch.int32,
                                        device=device)
            ans = k2.index(src, tensor_index)
            expected_row_splits = torch.tensor([0, 2, 5, 5, 6, 6, 7],
                                               dtype=torch.int32,
                                               device=device)
            self.assertTrue(
                torch.allclose(ans.row_splits(1), expected_row_splits))
            expected_values = torch.tensor([1, 2, 4, 5, 6, 3, 3],
                                           dtype=torch.int32,
                                           device=device)
            self.assertTrue(torch.allclose(ans.values(), expected_values))
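As a reference for the semantics this test asserts, here is a pure-Python sketch (an assumption based only on the expected values above, not on the k2 implementation): a tensor index gathers whole sublists, while a ragged index gathers the sublists named by each index sublist and concatenates them.

def index_ragged_with_tensor(src, indexes):
    # gather whole sublists, one per index entry
    return [src[i] for i in indexes]

def index_ragged_with_ragged(src, indexes):
    # concatenate the gathered sublists within each index sublist
    return [[v for i in sub for v in src[i]] for sub in indexes]

src = [[1, 2], [3], [], [4, 5, 6]]  # the ragged tensor built in the test
assert index_ragged_with_tensor(src, [0, 3, 2, 1, 2, 1]) == \
    [[1, 2], [4, 5, 6], [], [3], [], [3]]
assert index_ragged_with_ragged(src, [[0, 3], [], [2], [1, 2, 1, 0]]) == \
    [[1, 2, 4, 5, 6], [], [], [3, 3, 1, 2]]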
Example #3
def fsa_from_unary_function_ragged(src: Fsa, dest_arcs: _k2.RaggedArc,
                                   arc_map: _k2.RaggedInt) -> Fsa:
    '''Create an Fsa object, including autograd logic and propagating
    properties from the source FSA.

    This is intended to be called from unary functions on FSAs where the arc_map
    is an instance of _k2.RaggedInt.

    Args:
      src:
        The source Fsa, i.e. the arg to the unary function.
      dest_arcs:
        The raw output of the unary function, as output by whatever C++
        algorithm we used.
      arc_map:
        A map from arcs in `dest_arcs` to the corresponding arc-index in `src`,
        or -1 if the arc had no source arc (e.g. :func:`remove_epsilon`).
    Returns:
      Returns the resulting Fsa, with properties propagated appropriately, and
      autograd handled.
    '''
    dest = Fsa(dest_arcs)

    for name, value in src.named_tensor_attr(include_scores=False):
        setattr(dest, name, k2.index(value, arc_map))

    for name, value in src.named_non_tensor_attr():
        setattr(dest, name, value)

    k2.autograd_utils.phantom_index_and_sum_scores(dest, src.scores, arc_map)

    return dest
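A minimal sketch of the intended calling convention, assuming a hypothetical C++ binding `_k2.some_unary_op` that returns the output arcs together with a ragged arc map (the name and signature here are illustrative only, not part of k2):

def some_unary_op(src: Fsa) -> Fsa:
    # hypothetical binding returning (dest_arcs, ragged arc_map)
    dest_arcs, arc_map = _k2.some_unary_op(src.arcs)
    return fsa_from_unary_function_ragged(src, dest_arcs, arc_map)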
Example #4
    def forward(self, log_probs: torch.Tensor, targets: torch.Tensor,
                input_lengths: torch.Tensor,
                target_lengths: torch.Tensor) -> torch.Tensor:

        # log_probs is now [N, T, C]: batch_size x seq_length x alphabet_size
        log_probs = log_probs.permute(1, 0, 2).cpu()
        supervision_segments = torch.stack(
            (torch.arange(input_lengths.shape[0]),
             torch.zeros(input_lengths.shape[0]), input_lengths),
            1).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
        decoding_graph = self.graph_compiler.compile(targets.cpu(),
                                                     target_lengths)
        decoding_graph = k2.index(decoding_graph,
                                  indices.to(torch.int32)).to(log_probs.device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
        tot_scores = k2.get_tot_scores(target_graph,
                                       log_semiring=True,
                                       use_double_scores=True)
        (tot_score, tot_frames,
         all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                                   supervision_segments[:, 2])
        return -tot_score
Example #5
    def forward(ctx, fsas: Fsa, out_fsa: List[Fsa],
                unused_fsas_scores: torch.Tensor) -> torch.Tensor:
        '''Compute the union of all FSAs in an FsaVec.

        Args:
          fsas:
            The input FsaVec. Caution: We require that each fsa in the FsaVec
            is non-empty (i.e., with at least two states).
          out_fsa:
            A list containing one entry. Since this function can only return
            values of type `torch.Tensor`, we return the union result in the
            list.
          unused_fsas_scores:
            It is the same as `fsas.scores`, whose sole purpose is for autograd.
            It is not used in this function.
        '''
        need_arc_map = True
        ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map)
        out_fsa[0] = Fsa(ragged_arc)

        for name, value in fsas.named_tensor_attr(include_scores=False):
            value = k2.index(value, arc_map)
            setattr(out_fsa[0], name, value)

        for name, value in fsas.named_non_tensor_attr():
            setattr(out_fsa[0], name, value)
        ctx.arc_map = arc_map
        ctx.save_for_backward(unused_fsas_scores)

        return out_fsa[0].scores  # the return value will be discarded
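The matching `backward` is not shown here; conceptually it scatters the output-arc gradients back through `ctx.arc_map`. A hedged pure-torch sketch of that scatter, assuming a plain 1-D arc map where -1 marks newly created arcs with no source:

import torch

def scatter_grad(grad_out: torch.Tensor, arc_map: torch.Tensor,
                 num_src_arcs: int) -> torch.Tensor:
    # route each output arc's gradient to the source arc it came from;
    # arcs with arc_map == -1 contribute nothing
    grad_src = torch.zeros(num_src_arcs, dtype=grad_out.dtype)
    ok = arc_map >= 0
    grad_src.index_add_(0, arc_map[ok].long(), grad_out[ok])
    return grad_src

g = scatter_grad(torch.tensor([1., 2., 3.]),
                 torch.tensor([-1, 0, 0], dtype=torch.int32), 2)
assert torch.allclose(g, torch.tensor([5., 0.]))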
Example #6
def expand_ragged_attributes(
        fsas: Fsa,
        ret_arc_map: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''Turn ragged labels attached to this FSA into linear (Tensor) labels,
    expanding arcs into sequences of arcs as necessary to achieve this.
    Supports autograd. If `fsas` had no ragged attributes, returns `fsas`
    itself.

    Args:
      fsas:
        The input Fsa or FsaVec.
      ret_arc_map:
        If true, will return a pair (new_fsas, arc_map) with `arc_map` a
        tensor of int32 that maps from arcs in the result to arcs in `fsas`,
        with -1's for newly created arcs. If false, just returns new_fsas.
    '''
    ragged_attribute_tensors = []
    ragged_attribute_names = []
    for name, value in fsas.named_tensor_attr(include_scores=False):
        if isinstance(value, k2.RaggedInt):
            ragged_attribute_tensors.append(value)
            ragged_attribute_names.append(name)

    if len(ragged_attribute_tensors) == 0:
        if ret_arc_map:
            arc_map = torch.arange(fsas.num_arcs,
                                   dtype=torch.int32,
                                   device=fsas.device)
            return (fsas, arc_map)
        else:
            return fsas

    (dest_arcs, dest_labels,
     arc_map) = _k2.expand_arcs(fsas.arcs, ragged_attribute_tensors)

    # The rest of this function is a modified version of
    # `fsa_from_unary_function_tensor()`.
    dest = Fsa(dest_arcs)

    # Handle the non-ragged attributes
    for name, value in fsas.named_tensor_attr(include_scores=False):
        if not isinstance(value, k2.RaggedInt):
            setattr(dest, name, k2.index(value, arc_map))

    # Handle the attributes that were ragged but are now linear
    for name, value in zip(ragged_attribute_names, dest_labels):
        setattr(dest, name, value)

    # Copy non-tensor attributes
    for name, value in fsas.named_non_tensor_attr():
        setattr(dest, name, value)

    # make sure autograd works on the scores
    k2.autograd_utils.phantom_index_select_scores(dest, fsas.scores, arc_map)

    if ret_arc_map:
        return dest, arc_map
    else:
        return dest
Example #7
def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt],
              modified: bool = False,
              device: Optional[Union[torch.device, str]] = None) -> Fsa:
    '''Construct ctc graphs from symbols.

    Note:
      The scores of arcs in the returned FSA are all 0.

    Args:
      symbols:
        It can be one of the following types:

            - A list of list-of-integers, e.g., `[ [1, 2], [1, 2, 3] ]`
            - An instance of :class:`k2.RaggedInt`. Must have `num_axes() == 2`.

      modified:
        Option to select the type of CTC topology: the "standard" topology
        (the default, `modified=False`) makes the blank mandatory between a
        pair of identical symbols; `modified=True` selects a simplified
        topology without that constraint.
      device:
        Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a
        torch.device. If it is None, then the returned FSA is on CPU. It must
        be None if `symbols` is an instance of :class:`k2.RaggedInt`; in that
        case, the returned FSA is on the same device as `symbols`.

    Returns:
        An FsaVec containing the constructed CTC graphs, with `Dim0()` equal
        to `len(symbols)` if `symbols` is a List[List[int]], or to
        `symbols.dim0()` if it is a :class:`k2.RaggedInt`.
    '''
    if device is not None:
        device = torch.device(device)
        if device.type == 'cpu':
            gpu_id = -1
        else:
            assert device.type == 'cuda'
            gpu_id = device.index if device.index is not None else 0
    else:
        gpu_id = -1

    symbol_values = None
    if isinstance(symbols, k2.RaggedInt):
        assert device is None
        assert symbols.num_axes() == 2
        symbol_values = symbols.values()
    else:
        symbol_values = torch.tensor(
            [it for symbol in symbols for it in symbol], dtype=torch.int32,
            device=device)

    need_arc_map = True
    ragged_arc, arc_map = _k2.ctc_graph(symbols, gpu_id,
                                        modified, need_arc_map)
    aux_labels = k2.index(symbol_values, arc_map)
    fsa = Fsa(ragged_arc, aux_labels=aux_labels)
    return fsa
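A brief usage sketch, assuming a k2 version that exposes this function as `k2.ctc_graph` (the exact printed form of the result may differ across versions):

import k2

fsa_vec = k2.ctc_graph([[1, 2], [1, 2, 3]])
assert fsa_vec.shape[0] == 2  # one CTC graph per symbol sequence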
Example #8
def generate_nbest_list(lats: Fsa, num_paths: int) -> Nbest:
    '''Generate an n-best list from a lattice.

    Args:
      lats:
        The decoding lattice from the first pass after LM rescoring.
        lats is an FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`
      num_paths:
        Size of n for n-best list. CAUTION: After removing paths
        that represent the same token sequences, the number of paths
        in different sequences may not be equal.
    Returns:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.
    '''
    assert len(lats.shape) == 3

    # CAUTION: We use `phones` instead of `tokens` here because
    # :func:`compile_HLG` uses `phones`
    #
    # Note: compile_HLG is from k2-fsa/snowfall
    assert hasattr(lats, 'phones')

    assert not hasattr(lats, 'tokens')
    lats.tokens = lats.phones
    # we use tokens instead of phones in the following code

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # token_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains token IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][token_id]
    token_seqs = k2.index(lats.tokens, paths)

    # Remove epsilons (0s) and -1 from token_seqs
    token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)

    # unique_token_seqs is still a k2.RaggedInt with axes [seq][path][token_id],
    # but the number of paths in each sequence may be different.
    unique_token_seqs, _, _ = k2.ragged.unique_sequences(
        token_seqs, need_num_repeats=False, need_new2old_indexes=False)

    seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0)

    # Remove the seq axis.
    # Now unique_token_seqs has only two axes [path][token_id]
    unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0)

    token_fsas = k2.linear_fsa(unique_token_seqs)

    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
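For reference, the removal step above has straightforward list-of-lists semantics on the last axis; a pure-Python sketch (an illustration, not the k2 implementation):

def remove_values_leq(ragged, cutoff):
    # keep only values strictly greater than `cutoff` in each sublist
    return [[v for v in sub if v > cutoff] for sub in ragged]

# 0 is epsilon and -1 marks the end of a path; both are <= 0
assert remove_values_leq([[0, 5, -1], [7, 0, 3, -1]], 0) == [[5], [7, 3]]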
Example #9
def _intersect_device(a_fsas: k2.Fsa, b_fsas: k2.Fsa, b_to_a_map: torch.Tensor,
                      sorted_match_a: bool):
    '''This is a wrapper of k2.intersect_device and its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    CUDA OOM error.

    The arguments and return value of this function are the same as
    k2.intersect_device.
    '''
    # NOTE: You can decrease batch_size in case of CUDA out of memory error.
    batch_size = 500
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(a_fsas,
                                   b_fsas,
                                   b_to_a_map=b_to_a_map,
                                   sorted_match_a=sorted_match_a)

    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)

        fsas = k2.index(b_fsas, indexes)
        b_to_a = k2.index(b_to_a_map, indexes)
        path_lats = k2.intersect_device(a_fsas,
                                        fsas,
                                        b_to_a_map=b_to_a,
                                        sorted_match_a=sorted_match_a)
        ans.append(path_lats)

    return k2.cat(ans)
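The split computation is plain Python; for instance, with num_fsas = 1234 and batch_size = 500 it yields three half-open ranges:

import math

batch_size = 500
num_fsas = 1234
num_batches = int(math.ceil(float(num_fsas) / batch_size))
splits = [(i * batch_size, min((i + 1) * batch_size, num_fsas))
          for i in range(num_batches)]
assert splits == [(0, 500), (500, 1000), (1000, 1234)]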
Example #10
    def test_sort_sublist_descending(self):
        for device in self.devices:
            src = k2.RaggedInt('[ [3 2] [] [1 5 2]]').to(device)
            src_clone = src.clone()
            new2old = k2.ragged.sort_sublist(src,
                                             descending=True,
                                             need_new2old_indexes=True)
            sorted_src = k2.RaggedInt('[[3 2] [] [5 2 1]]')
            expected_new2old = torch.tensor([0, 1, 3, 4, 2],
                                            device=device,
                                            dtype=torch.int32)
            assert str(src) == str(sorted_src)
            assert torch.all(torch.eq(new2old, expected_new2old))

            expected_sorted = k2.index(src_clone.values(), new2old)
            sorted_values = src.values()
            assert torch.all(torch.eq(expected_sorted, sorted_values))
Example #11
    def test(self):
        for device in self.devices:
            src = torch.tensor([1, 2, 3, 4, 5, 6, 7],
                               dtype=torch.int32,
                               device=device)
            index_row_splits = torch.tensor([0, 2, 2, 3, 7],
                                            dtype=torch.int32,
                                            device=device)
            index_shape = k2.ragged.create_ragged_shape2(
                index_row_splits, None, 7)
            index_values = torch.tensor([0, 3, 2, 3, 5, 1, 3],
                                        dtype=torch.int32,
                                        device=device)
            ragged_index = k2.RaggedInt(index_shape, index_values)

            ans = k2.index(src, ragged_index)
            self.assertTrue(torch.allclose(ans.row_splits(1),
                                           index_row_splits))
            expected_values = torch.tensor([1, 4, 3, 4, 6, 2, 4],
                                           dtype=torch.int32,
                                           device=device)
            self.assertTrue(torch.allclose(ans.values(), expected_values))
Example #12
def _compute_mmi_loss_exact_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    The function name contains `exact`, which means it uses a version of
    intersection without pruning.

    `optimized` in the function name means this function is optimized
    in that it calls k2.intersect_dense only once.

    Note:
      It is faster at the cost of using more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C]
      texts:
        The transcript. Each element consists of space(s) separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
      graph_compiler:
        Used to build num_graphs and den_graphs
      P:
        Represents a bigram Fsa.
      den_scale:
        The scale applied to the denominator tot_scores.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # the aux_labels of num_graphs is k2.RaggedInt
    # but it is torch.Tensor for den_graphs.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.append() later
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # The motivation to concatenate num_graphs and den_graphs
    # is to reduce the number of calls to k2.intersect_dense.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.
    #
    # The following code computes a_to_b_map

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
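The two index-shuffling idioms above are easiest to see on a tiny case. With num_fsas = 3 (devices omitted), stack/transpose/reshape interleaves numerator and denominator indexes, and repeat/transpose/reshape duplicates each dense-FSA index:

import torch

num_fsas = 3
num_idx = torch.arange(num_fsas, dtype=torch.int32)             # [0, 1, 2]
den_idx = torch.full((num_fsas,), num_fsas, dtype=torch.int32)  # [3, 3, 3]
interleaved = torch.stack([num_idx, den_idx]).t().reshape(-1)
assert interleaved.tolist() == [0, 3, 1, 3, 2, 3]

a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1)
assert a_to_b_map.tolist() == [0, 0, 1, 1, 2, 2]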
Example #13
    def forward(ctx,
                a_fsas: Fsa,
                b_fsas: DenseFsaVec,
                out_fsa: List[Fsa],
                output_beam: float,
                unused_scores_a: torch.Tensor,
                unused_scores_b: torch.Tensor,
                a_to_b_map: Optional[torch.Tensor] = None,
                seqframe_idx_name: Optional[str] = None,
                frame_idx_name: Optional[str] = None) -> torch.Tensor:
        '''Intersect array of FSAs on CPU/GPU.

        Args:
          a_fsas:
            Input FsaVec, i.e., `decoding graphs`, one per sequence. It might
            just be a linear sequence of phones, or might be something more
            complicated. Must have number of FSAs equal to b_fsas.dim0(), if
            a_to_b_map not specified.
          b_fsas:
            Input FSAs that correspond to neural network output.
          out_fsa:
            A list containing ONLY one entry which will be set to the
            generated FSA on return. We pass it as a list since the return
            value can only be types of torch.Tensor in the `forward` function.
          output_beam:
            Pruning beam for the output of intersection (vs. best path);
            equivalent to kaldi's lattice-beam.  E.g. 8.
          unused_scores_a:
            It equals to `a_fsas.scores` and its sole purpose is for back
            propagation.
          unused_scores_b:
            It equals to `b_fsas.scores` and its sole purpose is for back
            propagation.
          a_to_b_map:
            Maps from FSA-index in a to FSA-index in b to use for it.
            If None, then we expect the number of FSAs in a_fsas to equal
            b_fsas.dim0().  If set, then it should be a Tensor with ndim=1
            and dtype=torch.int32, with a_to_b_map.shape[0] equal to the
            number of FSAs in a_fsas (i.e. a_fsas.shape[0] if
            len(a_fsas.shape) == 3, else 1); and elements 0 <= i < b_fsas.dim0().
          seqframe_idx_name:
            If set (e.g. to 'seqframe'), an attribute in the output will be
            created that encodes the sequence-index and the frame-index within
            that sequence; this is equivalent to a row-index into b_fsas.values,
            or, equivalently, an element in b_fsas.shape.
          frame_idx_name:
            If set (e.g. to 'frame'), an attribute in the output will be created
            that contains the frame-index within the corresponding sequence.
        Returns:
           Return `out_fsa[0].scores`.
        '''
        assert len(out_fsa) == 1

        ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense(
            a_fsas=a_fsas.arcs,
            b_fsas=b_fsas.dense_fsa_vec,
            a_to_b_map=a_to_b_map,
            output_beam=output_beam)

        out_fsa[0] = Fsa(ragged_arc)

        for name, a_value in a_fsas.named_tensor_attr(include_scores=False):
            value = k2.index(a_value, arc_map_a)
            setattr(out_fsa[0], name, value)

        for name, a_value in a_fsas.named_non_tensor_attr():
            setattr(out_fsa[0], name, a_value)

        ctx.arc_map_a = arc_map_a
        ctx.arc_map_b = arc_map_b

        ctx.save_for_backward(unused_scores_a, unused_scores_b)

        seqframe_idx = None
        if frame_idx_name is not None:
            num_cols = b_fsas.dense_fsa_vec.scores_dim1()
            seqframe_idx = arc_map_b // num_cols
            shape = b_fsas.dense_fsa_vec.shape()
            fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx)
            frame_idx = seqframe_idx - _k2.index_select(
                shape.row_splits(1), fsa_idx0)
            assert not hasattr(out_fsa[0], frame_idx_name)
            setattr(out_fsa[0], frame_idx_name, frame_idx)

        if seqframe_idx_name is not None:
            if seqframe_idx is None:
                num_cols = b_fsas.dense_fsa_vec.scores_dim1()
                seqframe_idx = arc_map_b // num_cols

            assert not hasattr(out_fsa[0], seqframe_idx_name)
            setattr(out_fsa[0], seqframe_idx_name, seqframe_idx)

        return out_fsa[0].scores
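The seqframe/frame arithmetic at the end can be checked on plain tensors. A small sketch with two sequences of 3 and 2 frames (the row_splits/row_ids values are made up for illustration):

import torch

row_splits = torch.tensor([0, 3, 5])        # frame offsets of each sequence
row_ids = torch.tensor([0, 0, 0, 1, 1])     # sequence id of every seqframe row
seqframe_idx = torch.tensor([1, 4])         # rows hit by two hypothetical arcs
fsa_idx0 = row_ids[seqframe_idx]            # -> [0, 1]
frame_idx = seqframe_idx - row_splits[fsa_idx0]  # within-sequence frame index
assert frame_idx.tolist() == [1, 1]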
Example #14
    def forward(ctx,
                a_fsas: Fsa,
                b_fsas: DenseFsaVec,
                out_fsa: List[Fsa],
                search_beam: float,
                output_beam: float,
                min_active_states: int,
                max_active_states: int,
                unused_scores_a: torch.Tensor,
                unused_scores_b: torch.Tensor,
                seqframe_idx_name: Optional[str] = None,
                frame_idx_name: Optional[str] = None) -> torch.Tensor:
        '''Intersect array of FSAs on CPU/GPU.

        Args:
          a_fsas:
            Input FsaVec, i.e., `decoding graphs`, one per sequence. It might
            just be a linear sequence of phones, or might be something more
            complicated. Must have either `a_fsas.shape[0] == b_fsas.dim0()`, or
            `a_fsas.shape[0] == 1` in which case the graph is shared.
          b_fsas:
            Input FSAs that correspond to neural network output.
          out_fsa:
            A list containing ONLY one entry which will be set to the
            generated FSA on return. We pass it as a list since the return
            value can only be types of torch.Tensor in the `forward` function.
          search_beam:
            Decoding beam, e.g. 20.  Smaller is faster, larger is more exact
            (less pruning). This is the default value; it may be modified by
            `min_active_states` and `max_active_states`.
          output_beam:
            Pruning beam for the output of intersection (vs. best path);
            equivalent to kaldi's lattice-beam.  E.g. 8.
          max_active_states:
            Maximum number of FSA states that are allowed to be active on any
            given frame for any given intersection/composition task. This is
            advisory, in that it will try not to exceed that but may not always
            succeed. You can use a very large number if no constraint is needed.
          min_active_states:
            Minimum number of FSA states that are allowed to be active on any
            given frame for any given intersection/composition task. This is
            advisory, in that it will try not to have fewer than this number
            active. Set it to zero if there is no constraint.
          unused_scores_a:
            It equals to `a_fsas.scores` and its sole purpose is for back
            propagation.
          unused_scores_b:
            It equals to `b_fsas.scores` and its sole purpose is for back
            propagation.
          seqframe_idx_name:
            If set (e.g. to 'seqframe'), an attribute in the output will be
            created that encodes the sequence-index and the frame-index within
            that sequence; this is equivalent to a row-index into b_fsas.values,
            or, equivalently, an element in b_fsas.shape.
          frame_idx_name:
            If set (e.g. to 'frame'), an attribute in the output will be created
            that contains the frame-index within the corresponding sequence.
        Returns:
           Return `out_fsa[0].scores`.
        '''
        assert len(out_fsa) == 1

        ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense_pruned(
            a_fsas=a_fsas.arcs,
            b_fsas=b_fsas.dense_fsa_vec,
            search_beam=search_beam,
            output_beam=output_beam,
            min_active_states=min_active_states,
            max_active_states=max_active_states)

        out_fsa[0] = Fsa(ragged_arc)

        for name, a_value in a_fsas.named_tensor_attr(include_scores=False):
            value = k2.index(a_value, arc_map_a)
            setattr(out_fsa[0], name, value)

        for name, a_value in a_fsas.named_non_tensor_attr():
            setattr(out_fsa[0], name, a_value)

        ctx.arc_map_a = arc_map_a
        ctx.arc_map_b = arc_map_b

        ctx.save_for_backward(unused_scores_a, unused_scores_b)

        seqframe_idx = None
        if frame_idx_name is not None:
            num_cols = b_fsas.dense_fsa_vec.scores_dim1()
            seqframe_idx = arc_map_b // num_cols
            shape = b_fsas.dense_fsa_vec.shape()
            fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx)
            frame_idx = seqframe_idx - _k2.index_select(
                shape.row_splits(1), fsa_idx0)
            assert not hasattr(out_fsa[0], frame_idx_name)
            setattr(out_fsa[0], frame_idx_name, frame_idx)

        if seqframe_idx_name is not None:
            if seqframe_idx is None:
                num_cols = b_fsas.dense_fsa_vec.scores_dim1()
                seqframe_idx = arc_map_b // num_cols

            assert not hasattr(out_fsa[0], seqframe_idx_name)
            setattr(out_fsa[0], seqframe_idx_name, seqframe_idx)

        return out_fsa[0].scores
Example #15
def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
                             num_paths: int) -> k2.Fsa:
    '''Decode using n-best list with LM rescoring.

    `lats` is a decoding lattice, which has 3 axes. This function first
    extracts `num_paths` paths from `lats` for each sequence using
    `k2.random_paths`. The `am_scores` of these paths are computed.
    For each path, its `lm_scores` is computed using `G` (which is an LM).
    The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
    The path with the greatest `tot_scores` within a sequence is used
    as the decoding output.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      num_paths:
        It is the size `n` in `n-best` list.
    Returns:
      An FsaVec representing the best decoding path for each sequence
      in the lattice.
    '''
    device = lats.device

    assert len(lats.shape) == 3
    assert hasattr(lats, 'aux_labels')
    assert hasattr(lats, 'lm_scores')

    assert G.shape == (1, None, None)
    assert G.device == device
    assert hasattr(G, 'aux_labels') is False

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # unique_word_seqs is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a seq.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=True, need_new2old_indexes=True)

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops,
                                  path_to_seq_map)

    # Now compute lm_scores
    b_to_a_map = torch.zeros_like(path_to_seq_map)
    lm_path_lats = _intersect_device(G,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=b_to_a_map,
                                     sorted_match_a=True)
    lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
    lm_scores = lm_path_lats.get_tot_scores(True, True)

    tot_scores = am_scores + lm_scores

    # Remember that we used `k2.ragged.unique_sequences` to remove repeated
    # paths to avoid redundant computation in `k2.intersect_device`.
    # Now we use `num_repeats` to correct the scores for each path.
    #
    # NOTE(fangjun): It is commented out as it leads to a worse WER
    # tot_scores = tot_scores * num_repeats.values()

    # TODO(fangjun): We may need to add `k2.RaggedDouble`
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))
    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
Example #16
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''
    (Ideas of this function are from Dan)

    It implements something like CTC prefix beam search using n-best lists.

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total scores
    of these sequences in the log semiring. The one with the max score
    is used as the decoding output.
    '''

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.

    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if inv_lats is already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels

    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    tot_scores = path_lats.get_tot_scores(True, True)
    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `paths`, we use `new2old`
    # here to convert argmax_indexes to the indexes into `paths`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
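The `new2old` composition near the end is an ordinary gather; a pure-torch sketch with made-up values:

import torch

# unique_sequences reordered 4 paths; new2old maps new position -> old position
new2old = torch.tensor([3, 0, 2, 1], dtype=torch.int32)
# suppose argmax_per_sublist picked unique-path 0 for seq 0 and 2 for seq 1
argmax_indexes = torch.tensor([0, 2])
best_path_indexes = new2old[argmax_indexes]
assert best_path_indexes.tolist() == [3, 2]  # indexes into the original `paths`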
Example #17
    def forward(
            self, nnet_output: torch.Tensor, texts: List,
            supervision_segments: torch.Tensor
    ) -> Tuple[torch.Tensor, int, int]:
        num_graphs, den_graphs = self.graph_compiler.compile(
            texts, self.P, replicate_den=False)

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

        device = num_graphs.device

        num_fsas = num_graphs.shape[0]
        assert dense_fsa_vec.dim0() == num_fsas

        assert den_graphs.shape[0] == 1

        # the aux_labels of num_graphs is k2.RaggedInt
        # but it is torch.Tensor for den_graphs.
        #
        # The following converts den_graphs.aux_labels
        # from torch.Tensor to k2.RaggedInt so that
        # we can use k2.append() later
        den_graphs.convert_attr_to_ragged_(name='aux_labels')

        num_den_graphs = k2.cat([num_graphs, den_graphs])

        # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
        # so the following reorders num_den_graphs.

        # [0, 1, 2, ... ]
        num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

        # [num_fsas, num_fsas, num_fsas, ... ]
        den_graphs_indexes = torch.tensor([num_fsas] * num_fsas,
                                          dtype=torch.int32)

        # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
        num_den_graphs_indexes = torch.stack(
            [num_graphs_indexes,
             den_graphs_indexes]).t().reshape(-1).to(device)

        num_den_reordered_graphs = k2.index(num_den_graphs,
                                            num_den_graphs_indexes)

        # [[0, 1, 2, ...]]
        a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

        # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
        a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

        num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                          dense_fsa_vec,
                                          output_beam=10.0,
                                          a_to_b_map=a_to_b_map)

        num_den_tot_scores = num_den_lats.get_tot_scores(
            log_semiring=True, use_double_scores=True)

        num_tot_scores = num_den_tot_scores[::2]
        den_tot_scores = num_den_tot_scores[1::2]

        tot_scores = num_tot_scores - self.den_scale * den_tot_scores
        tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
            tot_scores, supervision_segments[:, 2])
        return tot_score, tot_frames, all_frames
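De-interleaving the stacked total scores is ordinary strided slicing, e.g.:

import torch

# scores come back interleaved as [num0, den0, num1, den1, ...]
scores = torch.tensor([10., 1., 20., 2., 30., 3.])
num_tot, den_tot = scores[::2], scores[1::2]
assert num_tot.tolist() == [10.0, 20.0, 30.0]
assert den_tot.tolist() == [1.0, 2.0, 3.0]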
Example #18
    def forward(ctx, a_fsas: Fsa, b_fsas: DenseFsaVec, out_fsa: List[Fsa],
                output_beam: float, unused_scores_a: torch.Tensor,
                unused_scores_b: torch.Tensor) -> torch.Tensor:
        '''Intersect array of FSAs on CPU/GPU.

        Args:
          a_fsas:
            Input FsaVec, i.e., `decoding graphs`, one per sequence. It might
            just be a linear sequence of phones, or might be something more
            complicated. Must have `a_fsas.shape[0] == b_fsas.dim0()`.
          b_fsas:
            Input FSAs that correspond to neural network output.
          out_fsa:
            A list containing ONLY one entry which will be set to the
            generated FSA on return. We pass it as a list since the return
            value can only be types of torch.Tensor in the `forward` function.
          output_beam:
            Pruning beam for the output of intersection (vs. best path);
            equivalent to kaldi's lattice-beam.  E.g. 8.
          unused_scores_a:
            It equals to `a_fsas.scores` and its sole purpose is for back
            propagation.
          unused_scores_b:
            It equals to `b_fsas.scores` and its sole purpose is for back
            propagation.
        Returns:
           Return `out_fsa[0].scores`.
        '''
        assert len(out_fsa) == 1

        ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense(
            a_fsas=a_fsas.arcs,
            b_fsas=b_fsas.dense_fsa_vec,
            output_beam=output_beam)

        out_fsa[0] = Fsa(ragged_arc)

        for name, a_value in a_fsas.named_tensor_attr(include_scores=False):
            value = k2.index(a_value, arc_map_a)
            setattr(out_fsa[0], name, value)

        for name, a_value in a_fsas.named_non_tensor_attr():
            setattr(out_fsa[0], name, a_value)

        ctx.arc_map_a = arc_map_a
        ctx.arc_map_b = arc_map_b

        ctx.save_for_backward(unused_scores_a, unused_scores_b)

        return out_fsa[0].scores