def forward(ctx, fsas: Fsa, out_fsa: List[Fsa],
            unused_fsas_scores: torch.Tensor) -> torch.Tensor:
    '''Take the union of every FSA contained in an FsaVec.

    Caution:
      Each FSA in `fsas` is required to be non-empty, i.e., it must have
      at least two states.

    Args:
      ctx:
        The autograd context. The arc map produced by the union is stored
        on it as `ctx.arc_map` and `unused_fsas_scores` is saved for the
        backward pass.
      fsas:
        The input FsaVec.
      out_fsa:
        A list whose entry 0 is overwritten with the union result. A list
        is used because this function itself can only return values of
        type `torch.Tensor`.
      unused_fsas_scores:
        The same tensor as `fsas.scores`. It is not read here; it is
        passed only so that autograd can track it.

    Returns:
      The scores of the resulting FSA, i.e. `out_fsa[0].scores`. The
      caller is expected to discard this value and use `out_fsa[0]`.
    '''
    # `True` asks `_k2.union` to also return the arc map.
    ragged_arc, arc_map = _k2.union(fsas.arcs, True)

    union_fsa = Fsa.from_ragged_arc(ragged_arc)
    out_fsa[0] = union_fsa

    # Carry over every tensor attribute except `scores`, re-indexed by the
    # arc map so that it lines up with the arcs of the union.
    for name, value in fsas.named_tensor_attr():
        if name != 'scores':
            setattr(union_fsa, name, _k2.index_select(value, arc_map))

    # Non-tensor attributes are copied over unchanged.
    for name, value in fsas.named_non_tensor_attr():
        setattr(union_fsa, name, value)

    ctx.arc_map = arc_map
    ctx.save_for_backward(unused_fsas_scores)

    return union_fsa.scores
def forward(ctx, src: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    '''Gather rows of `src` along dimension 0 according to `index`.

    An entry of -1 in `index` produces a zero entry in the output.

    Caution:
      `index.dtype == torch.int32` and `index.ndim == 1`.

    Args:
      ctx:
        The autograd context; `src` and `index` are saved on it for the
        backward pass.
      src:
        The input tensor. Either 1-D or 2-D, with dtype torch.int32 or
        torch.float32.
      index:
        1-D tensor of dtype torch.int32 holding the indexes. Its elements
        should lie in the range `[-1..src.shape[0]-1]`; where an entry is
        -1 the corresponding entry of the returned value is 0.

    Returns:
      A tensor with shape `(index.numel(), *src.shape[1:])` and the same
      dtype as `src`. For `src.ndim == 1` the shape is
      `(index.shape[0],)` and `ans[i] == src[index[i]]`; for
      `src.ndim == 2` the shape is `(index.shape[0], src.shape[1])` and
      `ans[i, j] == src[index[i], j]`. Entries with `index[i] == -1` are
      zero.
    '''
    ctx.save_for_backward(src, index)
    selected = _k2.index_select(src, index)
    return selected
def compose_arc_maps(step1_arc_map: torch.Tensor,
                     step2_arc_map: torch.Tensor) -> torch.Tensor:
    '''Chain the arc maps of two consecutive Fsa operations.

    For every `i` in `0 .. step2_arc_map.numel() - 1`:

      - `ans_arc_map[i] = step1_arc_map[step2_arc_map[i]]`
        when `step2_arc_map[i]` is not -1;
      - `ans_arc_map[i] = -1` when `step2_arc_map[i]` is -1.

    Args:
      step1_arc_map:
        A 1-D tensor with dtype torch.int32 from the first Fsa operation.
      step2_arc_map:
        A 1-D tensor with dtype torch.int32 from the second Fsa operation.

    Returns:
      A 1-D tensor with dtype torch.int32 that has as many elements as
      `step2_arc_map`, i.e. `ans_arc_map.shape == step2_arc_map.shape`.
    '''
    # Both maps must be 1-D int32 tensors; step1 is checked before step2.
    for arc_map in (step1_arc_map, step2_arc_map):
        assert arc_map.ndim == 1
        assert arc_map.dtype == torch.int32

    # -1 entries in step2 select nothing from step1 and stay -1.
    composed = _k2.index_select(step1_arc_map,
                                step2_arc_map,
                                default_value=-1)
    return composed
def forward(ctx, out_fsa: Fsa, unused_in_fsa_scores: torch.Tensor,
            arc_map: torch.Tensor) -> torch.Tensor:
    '''Expose the scores of `out_fsa` while recording what backward needs.

    Args:
      ctx:
        The autograd context; `unused_in_fsa_scores` and `arc_map` are
        saved on it for the backward pass.
      out_fsa:
        The FSA whose scores are returned.
      unused_in_fsa_scores:
        Scores of the input FSA. Not read here; passed only so that it is
        saved for the backward pass.
      arc_map:
        Arc map relating `out_fsa` to the input FSA. Not read here; saved
        for the backward pass.

    Returns:
      `out_fsa.scores`.

    Note:
      `out_fsa.scores` is expected to equal
      `_k2.index_select(unused_in_fsa_scores, arc_map)`. This invariant is
      not verified here.
    '''
    ctx.save_for_backward(unused_in_fsa_scores, arc_map)
    return out_fsa.scores
def backward(ctx, out_grad: torch.Tensor) -> Tuple[torch.Tensor, None]:
    '''Propagate the gradient of the output back to the source tensor.

    Args:
      ctx:
        The autograd context. It must provide `ctx.indexes` (an object
        exposing `row_ids(1)` and `values()`; presumably a ragged int
        tensor, to be confirmed against the matching forward) and exactly
        one saved tensor, the source.
      out_grad:
        Gradient with respect to the output of the forward pass.

    Returns:
      A pair `(src_grad, None)`. `src_grad` is a float32 tensor with the
      shape and device of the saved source tensor; the second input
      receives no gradient.
    '''
    ragged_indexes = ctx.indexes
    (src,) = ctx.saved_tensors

    # Pick one entry of `out_grad` per element of `ragged_indexes`, using
    # the row id of that element.
    per_element_grad = _k2.index_select(out_grad, ragged_indexes.row_ids(1))

    src_grad = torch.zeros(src.shape,
                           dtype=torch.float32,
                           device=src.device,
                           requires_grad=False)

    # In-place update of `src_grad`; presumably accumulates
    # `per_element_grad[i]` at position `ragged_indexes.values()[i]`.
    _k2.index_add(ragged_indexes.values(), per_element_grad, src_grad)

    return src_grad, None
def backward(
        ctx, out_fsa_scores_grad: torch.Tensor
) -> Tuple[None, torch.Tensor, None]:  # noqa
    '''Propagate the gradient of the output scores to the input scores.

    Args:
      ctx:
        The autograd context. It must provide `ctx.arc_map` (an object
        exposing `row_ids(1)` and `values()`; presumably a ragged int
        tensor, to be confirmed against the matching forward) and exactly
        one saved tensor, the input FSA scores.
      out_fsa_scores_grad:
        Gradient with respect to the scores of the output FSA.

    Returns:
      A triple `(None, in_scores_grad, None)` matching the forward inputs
      `(out_fsa, unused_in_fsa_scores, arc_map)`. `in_scores_grad` is a
      float32 tensor with the shape and device of the saved input scores.
    '''
    (in_scores,) = ctx.saved_tensors
    ragged_arc_map = ctx.arc_map

    # Pick one entry of the output gradient per element of the arc map,
    # using the row id of that element.
    per_element_grad = _k2.index_select(out_fsa_scores_grad,
                                        ragged_arc_map.row_ids(1))

    in_scores_grad = torch.zeros(in_scores.shape,
                                 dtype=torch.float32,
                                 device=in_scores.device,
                                 requires_grad=False)

    # In-place update of `in_scores_grad`; presumably accumulates
    # `per_element_grad[i]` at position `ragged_arc_map.values()[i]`.
    _k2.index_add(ragged_arc_map.values(), per_element_grad, in_scores_grad)

    # Gradient slots, in order: out_fsa, unused_in_fsa_scores, arc_map.
    return None, in_scores_grad, None
def forward(ctx,
            a_fsas: Fsa,
            b_fsas: DenseFsaVec,
            out_fsa: List[Fsa],
            output_beam: float,
            unused_scores_a: torch.Tensor,
            unused_scores_b: torch.Tensor,
            a_to_b_map: Optional[torch.Tensor] = None,
            seqframe_idx_name: Optional[str] = None,
            frame_idx_name: Optional[str] = None) -> torch.Tensor:
    '''Intersect an array of FSAs with dense FSAs on CPU/GPU.

    Args:
      ctx:
        The autograd context. Both arc maps are stored on it as
        `ctx.arc_map_a` / `ctx.arc_map_b`, and the two score tensors are
        saved for the backward pass.
      a_fsas:
        Input FsaVec, i.e., the `decoding graphs`, one per sequence. It
        might just be a linear sequence of phones, or something more
        complicated. If `a_to_b_map` is not given, the number of FSAs
        must equal `b_fsas.dim0()`.
      b_fsas:
        Input FSAs that correspond to neural network output.
      out_fsa:
        A list containing ONLY one entry, which is set to the generated
        FSA on return. A list is used because `forward` can only return
        values of type `torch.Tensor`.
      output_beam:
        Pruning beam for the output of intersection (vs. best path);
        equivalent to kaldi's lattice-beam. E.g. 8.
      unused_scores_a:
        Equal to `a_fsas.scores`; present only for back propagation.
      unused_scores_b:
        Equal to `b_fsas.scores`; present only for back propagation.
      a_to_b_map:
        Maps from FSA-index in `a_fsas` to the FSA-index in `b_fsas` to
        use for it. If None, the number of FSAs in `a_fsas` is expected
        to equal `b_fsas.dim0()`. If set, it should be a tensor with
        ndim=1 and dtype=torch.int32, with `a_to_b_map.shape[0]` equal to
        the number of FSAs in `a_fsas` (i.e. `a_fsas.shape[0]` if
        `len(a_fsas.shape) == 3`, else 1), and with elements
        `0 <= i < b_fsas.dim0()`.
      seqframe_idx_name:
        If set (e.g. to 'seqframe'), an attribute with this name is
        created on the output. It encodes the sequence-index and the
        frame-index within that sequence; this is equivalent to a
        row-index into b_fsas.values, or, equivalently, an element in
        b_fsas.shape.
      frame_idx_name:
        If set (e.g. to 'frame'), an attribute with this name is created
        on the output. It contains the frame-index within the
        corresponding sequence.

    Returns:
      `out_fsa[0].scores`.
    '''
    assert len(out_fsa) == 1

    ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense(
        a_fsas=a_fsas.arcs,
        b_fsas=b_fsas.dense_fsa_vec,
        a_to_b_map=a_to_b_map,
        output_beam=output_beam)

    result = Fsa(ragged_arc)
    out_fsa[0] = result

    # Tensor attributes of `a_fsas` (scores excluded) follow the arcs
    # through `arc_map_a`; non-tensor attributes are copied unchanged.
    for name, a_value in a_fsas.named_tensor_attr(include_scores=False):
        setattr(result, name, k2.index(a_value, arc_map_a))

    for name, a_value in a_fsas.named_non_tensor_attr():
        setattr(result, name, a_value)

    ctx.arc_map_a = arc_map_a
    ctx.arc_map_b = arc_map_b
    ctx.save_for_backward(unused_scores_a, unused_scores_b)

    want_frame_idx = frame_idx_name is not None
    want_seqframe_idx = seqframe_idx_name is not None
    if not (want_frame_idx or want_seqframe_idx):
        return result.scores

    # Dividing an entry of `arc_map_b` by the number of score columns
    # gives the row it refers to, i.e. the combined sequence/frame index.
    dense = b_fsas.dense_fsa_vec
    seqframe_idx = arc_map_b // dense.scores_dim1()

    if want_frame_idx:
        shape = dense.shape()
        fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx)
        # Subtracting the first row of the owning sequence leaves the
        # frame index within that sequence.
        first_row = _k2.index_select(shape.row_splits(1), fsa_idx0)
        frame_idx = seqframe_idx - first_row
        assert not hasattr(result, frame_idx_name)
        setattr(result, frame_idx_name, frame_idx)

    if want_seqframe_idx:
        assert not hasattr(result, seqframe_idx_name)
        setattr(result, seqframe_idx_name, seqframe_idx)

    return result.scores
def forward(ctx,
            a_fsas: Fsa,
            b_fsas: DenseFsaVec,
            out_fsa: List[Fsa],
            search_beam: float,
            output_beam: float,
            min_active_states: int,
            max_active_states: int,
            unused_scores_a: torch.Tensor,
            unused_scores_b: torch.Tensor,
            seqframe_idx_name: Optional[str] = None,
            frame_idx_name: Optional[str] = None) -> torch.Tensor:
    '''Intersect an array of FSAs with dense FSAs on CPU/GPU, with pruning.

    Args:
      ctx:
        The autograd context. Both arc maps are stored on it as
        `ctx.arc_map_a` / `ctx.arc_map_b`, and the two score tensors are
        saved for the backward pass.
      a_fsas:
        Input FsaVec, i.e., the `decoding graphs`, one per sequence. It
        might just be a linear sequence of phones, or something more
        complicated. Must have either
        `a_fsas.shape[0] == b_fsas.dim0()`, or `a_fsas.shape[0] == 1`, in
        which case the graph is shared.
      b_fsas:
        Input FSAs that correspond to neural network output.
      out_fsa:
        A list containing ONLY one entry, which is set to the generated
        FSA on return. A list is used because `forward` can only return
        values of type `torch.Tensor`.
      search_beam:
        Decoding beam, e.g. 20. Smaller is faster, larger is more exact
        (less pruning). This is the default value; it may be modified by
        `min_active_states` and `max_active_states`.
      output_beam:
        Pruning beam for the output of intersection (vs. best path);
        equivalent to kaldi's lattice-beam. E.g. 8.
      min_active_states:
        Minimum number of FSA states that are allowed to be active on any
        given frame for any given intersection/composition task. This is
        advisory, in that it will try not to have fewer than this number
        active. Set it to zero if there is no constraint.
      max_active_states:
        Maximum number of FSA states that are allowed to be active on any
        given frame for any given intersection/composition task. This is
        advisory, in that it will try not to exceed that but may not
        always succeed. Use a very large number if no constraint is
        needed.
      unused_scores_a:
        Equal to `a_fsas.scores`; present only for back propagation.
      unused_scores_b:
        Equal to `b_fsas.scores`; present only for back propagation.
      seqframe_idx_name:
        If set (e.g. to 'seqframe'), an attribute with this name is
        created on the output. It encodes the sequence-index and the
        frame-index within that sequence; this is equivalent to a
        row-index into b_fsas.values, or, equivalently, an element in
        b_fsas.shape.
      frame_idx_name:
        If set (e.g. to 'frame'), an attribute with this name is created
        on the output. It contains the frame-index within the
        corresponding sequence.

    Returns:
      `out_fsa[0].scores`.
    '''
    assert len(out_fsa) == 1

    ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense_pruned(
        a_fsas=a_fsas.arcs,
        b_fsas=b_fsas.dense_fsa_vec,
        search_beam=search_beam,
        output_beam=output_beam,
        min_active_states=min_active_states,
        max_active_states=max_active_states)

    result = Fsa(ragged_arc)
    out_fsa[0] = result

    # Tensor attributes of `a_fsas` (scores excluded) follow the arcs
    # through `arc_map_a`; non-tensor attributes are copied unchanged.
    for name, a_value in a_fsas.named_tensor_attr(include_scores=False):
        setattr(result, name, k2.index(a_value, arc_map_a))

    for name, a_value in a_fsas.named_non_tensor_attr():
        setattr(result, name, a_value)

    ctx.arc_map_a = arc_map_a
    ctx.arc_map_b = arc_map_b
    ctx.save_for_backward(unused_scores_a, unused_scores_b)

    want_frame_idx = frame_idx_name is not None
    want_seqframe_idx = seqframe_idx_name is not None
    if not (want_frame_idx or want_seqframe_idx):
        return result.scores

    # Dividing an entry of `arc_map_b` by the number of score columns
    # gives the row it refers to, i.e. the combined sequence/frame index.
    dense = b_fsas.dense_fsa_vec
    seqframe_idx = arc_map_b // dense.scores_dim1()

    if want_frame_idx:
        shape = dense.shape()
        fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx)
        # Subtracting the first row of the owning sequence leaves the
        # frame index within that sequence.
        first_row = _k2.index_select(shape.row_splits(1), fsa_idx0)
        frame_idx = seqframe_idx - first_row
        assert not hasattr(result, frame_idx_name)
        setattr(result, frame_idx_name, frame_idx)

    if want_seqframe_idx:
        assert not hasattr(result, seqframe_idx_name)
        setattr(result, seqframe_idx_name, seqframe_idx)

    return result.scores
def forward(ctx,
            a_fsas: Fsa,
            b_fsas: DenseFsaVec,
            out_fsa: List[Fsa],
            beam: float,
            max_active_states: int,
            min_active_states: int,
            unused_scores_a: torch.Tensor,
            unused_scores_b: torch.Tensor) -> torch.Tensor:
    '''Intersect an array of FSAs with dense FSAs on CPU/GPU, with pruning.

    Args:
      ctx:
        The autograd context. Both arc maps are stored on it as
        `ctx.arc_map_a` / `ctx.arc_map_b`, and the two score tensors are
        saved for the backward pass.
      a_fsas:
        Input FsaVec, i.e., the `decoding graphs`, one per sequence. It
        might just be a linear sequence of phones, or something more
        complicated. Must have either
        `a_fsas.shape[0] == b_fsas.dim0()`, or `a_fsas.shape[0] == 1`, in
        which case the graph is shared.
      b_fsas:
        Input FSAs that correspond to neural network output.
      out_fsa:
        A list containing ONLY one entry, which is set to the generated
        FSA on return. A list is used because `forward` can only return
        values of type `torch.Tensor`.
      beam:
        Decoding beam, e.g. 10. Smaller is faster, larger is more exact
        (less pruning). This is the default value; it may be modified by
        `min_active_states` and `max_active_states`.
      max_active_states:
        Maximum number of FSA states that are allowed to be active on any
        given frame for any given intersection/composition task. This is
        advisory, in that it will try not to exceed that but may not
        always succeed. Use a very large number if no constraint is
        needed.
      min_active_states:
        Minimum number of FSA states that are allowed to be active on any
        given frame for any given intersection/composition task. This is
        advisory, in that it will try not to have fewer than this number
        active. Set it to zero if there is no constraint.
      unused_scores_a:
        Equal to `a_fsas.scores`; present only for back propagation.
      unused_scores_b:
        Equal to `b_fsas.scores`; present only for back propagation.

    Returns:
      `out_fsa[0].scores`.
    '''
    assert len(out_fsa) == 1

    ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense_pruned(
        a_fsas=a_fsas.arcs,
        b_fsas=b_fsas.dense_fsa_vec,
        beam=beam,
        max_active_states=max_active_states,
        min_active_states=min_active_states)

    result = Fsa.from_ragged_arc(ragged_arc)
    out_fsa[0] = result

    # Carry over every tensor attribute of `a_fsas` except `scores`,
    # re-indexed through `arc_map_a`.
    for name, a_value in a_fsas.named_tensor_attr():
        if name != 'scores':
            setattr(result, name, _k2.index_select(a_value, arc_map_a))

    # Non-tensor attributes are copied over unchanged.
    for name, a_value in a_fsas.named_non_tensor_attr():
        setattr(result, name, a_value)

    ctx.arc_map_a = arc_map_a
    ctx.arc_map_b = arc_map_b
    ctx.save_for_backward(unused_scores_a, unused_scores_b)

    return result.scores