Esempio n. 1
0
    def forward(ctx, fsas: Fsa, out_fsa: List[Fsa],
                unused_fsas_scores: torch.Tensor) -> torch.Tensor:
        '''Build the union of every fsa contained in an FsaVec.

        Args:
          fsas:
            The FsaVec to take the union of. Caution: every fsa in it has
            to be non-empty, i.e., it must have at least two states.
          out_fsa:
            A one-entry list used as an output slot. `forward` can only
            return values of type `torch.Tensor`, so the resulting Fsa is
            written to `out_fsa[0]` instead.
          unused_fsas_scores:
            The same tensor as `fsas.scores`. It is passed in only for
            autograd; it is saved on `ctx` but not otherwise read here.

        Returns:
          The scores of `out_fsa[0]`. The return value will be discarded.
        '''
        # The arc map is always requested: it is used below to re-index
        # the tensor attributes and it is also stored on `ctx`.
        union_arcs, arc_map = _k2.union(fsas.arcs, True)
        result = Fsa.from_ragged_arc(union_arcs)
        out_fsa[0] = result

        # Tensor attributes are re-indexed through the arc map; `scores`
        # is deliberately skipped.
        for attr_name, attr_value in fsas.named_tensor_attr():
            if attr_name != 'scores':
                setattr(result, attr_name,
                        _k2.index_select(attr_value, arc_map))

        # Non-tensor attributes are carried over unchanged.
        for attr_name, attr_value in fsas.named_non_tensor_attr():
            setattr(result, attr_name, attr_value)

        ctx.save_for_backward(unused_fsas_scores)
        ctx.arc_map = arc_map

        return result.scores  # the return value will be discarded
Esempio n. 2
0
    def forward(ctx, src: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        '''Gather entries of `src` along dimension 0 as directed by `index`.

        Wherever `index` holds -1, the matching entry of the returned tensor
        is 0.

        Caution:
          `index` must satisfy `index.dtype == torch.int32` and
          `index.ndim == 1`.

        Args:
          src:
            The tensor to select from. It is either 1-D or 2-D and its dtype
            is torch.int32 or torch.float32.
          index:
            A 1-D tensor of dtype torch.int32 holding the indexes. Its
            elements should lie in the range `[-1..src.shape[0]-1]`; an
            element equal to -1 yields 0 at that position of the result.

        Returns:
          A tensor whose dtype equals that of `src` and whose shape is
          (index.numel(), *src.shape[1:]). That is (index.shape[0],) when
          `src.ndim == 1` and (index.shape[0], src.shape[1]) when
          `src.ndim == 2`.
          For `src.ndim == 1` it satisfies `ans[i] == src[index[i]]`, and for
          `src.ndim == 2` it satisfies `ans[i,j] == src[index[i],j]`; the
          exception is any `i` with `index[i] == -1`, where the result is
          zero.
        '''
        # Both inputs are saved on `ctx` before the selection is computed.
        ctx.save_for_backward(src, index)
        selected = _k2.index_select(src, index)
        return selected
Esempio n. 3
0
File: ops.py Progetto: entn-at/k2
def compose_arc_maps(step1_arc_map: torch.Tensor,
                     step2_arc_map: torch.Tensor) -> torch.Tensor:
    '''Chain the arc maps produced by two consecutive Fsa operations.

    For every i from 0 to `step2_arc_map.numel() - 1` the result obeys:

        - ans_arc_map[i] = -1 when step2_arc_map[i] is -1
        - ans_arc_map[i] = step1_arc_map[step2_arc_map[i]] otherwise

    Args:
      step1_arc_map:
        The arc map of the first Fsa operation; a 1-D tensor with dtype
        torch.int32.
      step2_arc_map:
        The arc map of the second Fsa operation; a 1-D tensor with dtype
        torch.int32.
    Returns:
      A 1-D tensor with dtype torch.int32 that has as many elements as
      step2_arc_map, i.e., ans_arc_map.shape == step2_arc_map.shape.
    '''
    # Both maps have to be 1-D int32 tensors. `step1_arc_map` is checked
    # first, and for each map `ndim` is checked before `dtype`.
    for arc_map in (step1_arc_map, step2_arc_map):
        assert arc_map.ndim == 1
        assert arc_map.dtype == torch.int32

    # Entries of `step2_arc_map` equal to -1 map to the default value -1.
    composed = _k2.index_select(step1_arc_map,
                                step2_arc_map,
                                default_value=-1)
    return composed
Esempio n. 4
0
    def forward(ctx, out_fsa: Fsa, unused_in_fsa_scores: torch.Tensor,
                arc_map: torch.Tensor) -> torch.Tensor:
        '''Return the scores of `out_fsa` and save the tensors for backward.

        Args:
          out_fsa:
            The Fsa whose `scores` are returned.
          unused_in_fsa_scores:
            Presumably the scores of the input Fsa (going by its name). It
            is not read here; it is only saved on `ctx`.
          arc_map:
            The arc map tensor. It is saved on `ctx` together with
            `unused_in_fsa_scores`.

        Returns:
          `out_fsa.scores`.
        '''
        # A permanently disabled (`if False:`) debug check was removed from
        # here. It asserted that `out_fsa.scores` equals
        # `_k2.index_select(unused_in_fsa_scores, arc_map)`.
        ctx.save_for_backward(unused_in_fsa_scores, arc_map)
        return out_fsa.scores
Esempio n. 5
0
 def backward(ctx, out_grad: torch.Tensor) -> Tuple[torch.Tensor, None]:
     '''Compute the gradient for the tensor saved during forward.

     Args:
       out_grad:
         The gradient with respect to the output of forward.

     Returns:
       A tuple `(grad, None)`. `grad` is a torch.float32 tensor with the
       shape and device of the saved tensor.
     '''
     (src,) = ctx.saved_tensors
     ragged_indexes = ctx.indexes

     # The gradient buffer starts out as zeros and is filled in place below.
     src_grad = torch.zeros(src.shape,
                            device=src.device,
                            dtype=torch.float32,
                            requires_grad=False)

     # `out_grad` is first expanded with `row_ids(1)` of the indexes. The
     # expanded values are then added into `src_grad` at the positions in
     # `ragged_indexes.values()` (presumably in place; confirm against
     # `_k2.index_add`).
     per_element_grad = _k2.index_select(out_grad,
                                         ragged_indexes.row_ids(1))
     _k2.index_add(ragged_indexes.values(), per_element_grad, src_grad)
     return src_grad, None
Esempio n. 6
0
    def backward(
        ctx, out_fsa_scores_grad: torch.Tensor
    ) -> Tuple[None, torch.Tensor, None]:  # noqa
        '''Compute the gradient for the input scores saved during forward.

        Args:
          out_fsa_scores_grad:
            The gradient with respect to the scores returned by forward.

        Returns:
          A 3-tuple with one entry per argument: None for `out_fsa`, the
          gradient for `unused_in_fsa_scores`, and None for `arc_map`. The
          gradient is a torch.float32 tensor with the shape and device of
          the saved scores.
        '''
        # `ctx.arc_map` is used through its `row_ids(1)` and `values()`
        # methods, so it is expected to be a ragged object.
        ragged_map = ctx.arc_map
        (in_scores,) = ctx.saved_tensors

        # The gradient buffer starts out as zeros and is filled in place
        # below.
        in_scores_grad = torch.zeros(in_scores.shape,
                                     device=in_scores.device,
                                     dtype=torch.float32,
                                     requires_grad=False)

        # The output gradient is expanded with `row_ids(1)` and then added
        # into the buffer at the positions in `ragged_map.values()`
        # (presumably in place; confirm against `_k2.index_add`).
        per_element_grad = _k2.index_select(out_fsa_scores_grad,
                                            ragged_map.row_ids(1))
        _k2.index_add(ragged_map.values(), per_element_grad, in_scores_grad)

        # Entry order: out_fsa, unused_in_fsa_scores, arc_map.
        return None, in_scores_grad, None
Esempio n. 7
0
    def forward(ctx,
                a_fsas: Fsa,
                b_fsas: DenseFsaVec,
                out_fsa: List[Fsa],
                output_beam: float,
                unused_scores_a: torch.Tensor,
                unused_scores_b: torch.Tensor,
                a_to_b_map: Optional[torch.Tensor] = None,
                seqframe_idx_name: Optional[str] = None,
                frame_idx_name: Optional[str] = None) -> torch.Tensor:
        '''Intersect an array of FSAs with dense FSAs on CPU/GPU.

        Args:
          a_fsas:
            The input FsaVec, i.e., the `decoding graphs`, one per sequence.
            A graph may be a plain linear sequence of phones or something
            more complicated. When `a_to_b_map` is not given, the number of
            FSAs must equal b_fsas.dim0().
          b_fsas:
            The input FSAs corresponding to the neural network output.
          out_fsa:
            A list with ONLY one entry; on return that entry is the
            generated FSA. A list is used because `forward` can only
            return values of type torch.Tensor.
          output_beam:
            The pruning beam applied to the output of the intersection
            (vs. best path); equivalent to kaldi's lattice-beam. E.g. 8.
          unused_scores_a:
            Equal to `a_fsas.scores`; it is here solely for back
            propagation.
          unused_scores_b:
            Equal to `b_fsas.scores`; it is here solely for back
            propagation.
          a_to_b_map:
            Maps an FSA-index in a to the FSA-index in b to use for it.
            If None, the number of FSAs in a_fsas is expected to equal
            b_fsas.dim0(). If set, it should be a Tensor with ndim=1 and
            dtype=torch.int32 whose shape[0] equals the number of FSAs in
            a_fsas (i.e. a_fsas.shape[0] if len(a_fsas.shape) == 3, else 1)
            and whose elements satisfy 0 <= i < b_fsas.dim0().
          seqframe_idx_name:
            If set (e.g. to 'seqframe'), an attribute with this name is
            created on the output. It encodes the sequence-index and the
            frame-index within that sequence; this is equivalent to a
            row-index into b_fsas.values, or, equivalently, an element in
            b_fsas.shape.
          frame_idx_name:
            If set (e.g. to 'frame'), an attribute with this name is created
            on the output. It holds the frame-index within the
            corresponding sequence.
        Returns:
           Return `out_fsa[0].scores`.
        '''
        assert len(out_fsa) == 1

        dense_fsa_vec = b_fsas.dense_fsa_vec
        ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense(
            a_fsas=a_fsas.arcs,
            b_fsas=dense_fsa_vec,
            a_to_b_map=a_to_b_map,
            output_beam=output_beam)

        result = Fsa(ragged_arc)
        out_fsa[0] = result

        # Tensor attributes of `a_fsas` (scores excluded) are re-indexed
        # with the arc map of `a_fsas`.
        for attr_name, attr_value in a_fsas.named_tensor_attr(
                include_scores=False):
            setattr(result, attr_name, k2.index(attr_value, arc_map_a))

        # Non-tensor attributes are carried over unchanged.
        for attr_name, attr_value in a_fsas.named_non_tensor_attr():
            setattr(result, attr_name, attr_value)

        ctx.save_for_backward(unused_scores_a, unused_scores_b)
        ctx.arc_map_a = arc_map_a
        ctx.arc_map_b = arc_map_b

        want_frame_idx = frame_idx_name is not None
        want_seqframe_idx = seqframe_idx_name is not None

        # Both optional attributes are derived from the same row index, so
        # it is computed once when either attribute is requested.
        if want_frame_idx or want_seqframe_idx:
            seqframe_idx = arc_map_b // dense_fsa_vec.scores_dim1()

        if want_frame_idx:
            # Subtract the first row of the owning sequence to obtain the
            # frame index within that sequence.
            shape = dense_fsa_vec.shape()
            fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx)
            first_row = _k2.index_select(shape.row_splits(1), fsa_idx0)
            frame_idx = seqframe_idx - first_row
            assert not hasattr(result, frame_idx_name)
            setattr(result, frame_idx_name, frame_idx)

        if want_seqframe_idx:
            assert not hasattr(result, seqframe_idx_name)
            setattr(result, seqframe_idx_name, seqframe_idx)

        return result.scores
Esempio n. 8
0
    def forward(ctx,
                a_fsas: Fsa,
                b_fsas: DenseFsaVec,
                out_fsa: List[Fsa],
                search_beam: float,
                output_beam: float,
                min_active_states: int,
                max_active_states: int,
                unused_scores_a: torch.Tensor,
                unused_scores_b: torch.Tensor,
                seqframe_idx_name: Optional[str] = None,
                frame_idx_name: Optional[str] = None) -> torch.Tensor:
        '''Intersect an array of FSAs with dense FSAs on CPU/GPU, with pruning.

        Args:
          a_fsas:
            The input FsaVec, i.e., the `decoding graphs`, one per sequence.
            A graph may be a plain linear sequence of phones or something
            more complicated. Either `a_fsas.shape[0] == b_fsas.dim0()` must
            hold, or `a_fsas.shape[0] == 1`, in which case the graph is
            shared.
          b_fsas:
            The input FSAs corresponding to the neural network output.
          out_fsa:
            A list with ONLY one entry; on return that entry is the
            generated FSA. A list is used because `forward` can only
            return values of type torch.Tensor.
          search_beam:
            The decoding beam, e.g. 20. Smaller is faster, larger is more
            exact (less pruning). This is the default value; it may be
            modified by `min_active_states` and `max_active_states`.
          output_beam:
            The pruning beam applied to the output of the intersection
            (vs. best path); equivalent to kaldi's lattice-beam. E.g. 8.
          min_active_states:
            The minimum number of FSA states allowed to be active on any
            given frame for any given intersection/composition task. It is
            advisory: the algorithm tries not to have fewer than this
            number active. Set it to zero if there is no constraint.
          max_active_states:
            The maximum number of FSA states allowed to be active on any
            given frame for any given intersection/composition task. It is
            advisory: the algorithm tries not to exceed it but may not
            always succeed. Use a very large number if no constraint is
            needed.
          unused_scores_a:
            Equal to `a_fsas.scores`; it is here solely for back
            propagation.
          unused_scores_b:
            Equal to `b_fsas.scores`; it is here solely for back
            propagation.
          seqframe_idx_name:
            If set (e.g. to 'seqframe'), an attribute with this name is
            created on the output. It encodes the sequence-index and the
            frame-index within that sequence; this is equivalent to a
            row-index into b_fsas.values, or, equivalently, an element in
            b_fsas.shape.
          frame_idx_name:
            If set (e.g. to 'frame'), an attribute with this name is created
            on the output. It holds the frame-index within the
            corresponding sequence.
        Returns:
           Return `out_fsa[0].scores`.
        '''
        assert len(out_fsa) == 1

        dense_fsa_vec = b_fsas.dense_fsa_vec
        ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense_pruned(
            a_fsas=a_fsas.arcs,
            b_fsas=dense_fsa_vec,
            search_beam=search_beam,
            output_beam=output_beam,
            min_active_states=min_active_states,
            max_active_states=max_active_states)

        result = Fsa(ragged_arc)
        out_fsa[0] = result

        # Tensor attributes of `a_fsas` (scores excluded) are re-indexed
        # with the arc map of `a_fsas`.
        for attr_name, attr_value in a_fsas.named_tensor_attr(
                include_scores=False):
            setattr(result, attr_name, k2.index(attr_value, arc_map_a))

        # Non-tensor attributes are carried over unchanged.
        for attr_name, attr_value in a_fsas.named_non_tensor_attr():
            setattr(result, attr_name, attr_value)

        ctx.save_for_backward(unused_scores_a, unused_scores_b)
        ctx.arc_map_a = arc_map_a
        ctx.arc_map_b = arc_map_b

        want_frame_idx = frame_idx_name is not None
        want_seqframe_idx = seqframe_idx_name is not None

        # Both optional attributes are derived from the same row index, so
        # it is computed once when either attribute is requested.
        if want_frame_idx or want_seqframe_idx:
            seqframe_idx = arc_map_b // dense_fsa_vec.scores_dim1()

        if want_frame_idx:
            # Subtract the first row of the owning sequence to obtain the
            # frame index within that sequence.
            shape = dense_fsa_vec.shape()
            fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx)
            first_row = _k2.index_select(shape.row_splits(1), fsa_idx0)
            frame_idx = seqframe_idx - first_row
            assert not hasattr(result, frame_idx_name)
            setattr(result, frame_idx_name, frame_idx)

        if want_seqframe_idx:
            assert not hasattr(result, seqframe_idx_name)
            setattr(result, seqframe_idx_name, seqframe_idx)

        return result.scores
Esempio n. 9
0
    def forward(ctx, a_fsas: Fsa, b_fsas: DenseFsaVec, out_fsa: List[Fsa],
                beam: float, max_active_states: int, min_active_states: int,
                unused_scores_a: torch.Tensor,
                unused_scores_b: torch.Tensor) -> torch.Tensor:
        '''Intersect an array of FSAs with dense FSAs on CPU/GPU, with pruning.

        Args:
          a_fsas:
            The input FsaVec, i.e., the `decoding graphs`, one per sequence.
            A graph may be a plain linear sequence of phones or something
            more complicated. Either `a_fsas.shape[0] == b_fsas.dim0()` must
            hold, or `a_fsas.shape[0] == 1`, in which case the graph is
            shared.
          b_fsas:
            The input FSAs corresponding to the neural network output.
          out_fsa:
            A list with ONLY one entry; on return that entry is the
            generated FSA. A list is used because `forward` can only
            return values of type torch.Tensor.
          beam:
            The decoding beam, e.g. 10. Smaller is faster, larger is more
            exact (less pruning). This is the default value; it may be
            modified by `min_active_states` and `max_active_states`.
          max_active_states:
            The maximum number of FSA states allowed to be active on any
            given frame for any given intersection/composition task. It is
            advisory: the algorithm tries not to exceed it but may not
            always succeed. Use a very large number if no constraint is
            needed.
          min_active_states:
            The minimum number of FSA states allowed to be active on any
            given frame for any given intersection/composition task. It is
            advisory: the algorithm tries not to have fewer than this
            number active. Set it to zero if there is no constraint.
          unused_scores_a:
            Equal to `a_fsas.scores`; it is here solely for back
            propagation.
          unused_scores_b:
            Equal to `b_fsas.scores`; it is here solely for back
            propagation.
        Returns:
           Return `out_fsa[0].scores`.
        '''
        assert len(out_fsa) == 1

        ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense_pruned(
            a_fsas=a_fsas.arcs,
            b_fsas=b_fsas.dense_fsa_vec,
            beam=beam,
            max_active_states=max_active_states,
            min_active_states=min_active_states)

        result = Fsa.from_ragged_arc(ragged_arc)
        out_fsa[0] = result

        # Tensor attributes of `a_fsas` are re-indexed with the arc map of
        # `a_fsas`; `scores` is deliberately skipped.
        for attr_name, attr_value in a_fsas.named_tensor_attr():
            if attr_name != 'scores':
                setattr(result, attr_name,
                        _k2.index_select(attr_value, arc_map_a))

        # Non-tensor attributes are carried over unchanged.
        for attr_name, attr_value in a_fsas.named_non_tensor_attr():
            setattr(result, attr_name, attr_value)

        ctx.save_for_backward(unused_scores_a, unused_scores_b)
        ctx.arc_map_a = arc_map_a
        ctx.arc_map_b = arc_map_b

        return result.scores