def forward(ctx, fsas: Fsa, out_fsa: List[Fsa],
            unused_fsas_scores: torch.Tensor) -> torch.Tensor:
    '''Compute the union of all fsas in a FsaVec.

    Args:
      fsas:
        The input FsaVec. Caution: We require that each fsa in the FsaVec
        is non-empty (i.e., with at least two states).
      out_fsa:
        A list containing one entry. Since this function can only return
        values of type `torch.Tensor`, the union result is stored into
        ``out_fsa[0]`` instead of being returned.
      unused_fsas_scores:
        The same as `fsas.scores`; its sole purpose is for autograd.
        It is not read by this function, only saved on ``ctx``.

    Returns:
      ``out_fsa[0].scores``. This value is discarded by the caller; the
      real result is the Fsa placed in ``out_fsa[0]``.
    '''
    # Ask _k2.union for the arc map as well, so that per-arc attributes of
    # the input can be gathered onto the output arcs below.
    ragged_arc, arc_map = _k2.union(fsas.arcs, True)
    result = Fsa(ragged_arc)
    out_fsa[0] = result

    # Tensor attributes (scores excluded) are gathered through arc_map.
    for attr_name, attr_value in fsas.named_tensor_attr(include_scores=False):
        setattr(result, attr_name, k2.index(attr_value, arc_map))

    # Non-tensor attributes are copied over unchanged.
    for attr_name, attr_value in fsas.named_non_tensor_attr():
        setattr(result, attr_name, attr_value)

    # Saved on ctx, presumably for the matching backward() — confirm there.
    ctx.arc_map = arc_map
    ctx.save_for_backward(unused_fsas_scores)

    return result.scores  # the return value will be discarded
def union(fsas: Fsa) -> Fsa:
    '''Compute the union of a FsaVec.

    Caution:
      We require that every fsa in fsas is non-empty, i.e., contains
      at least two states.

    Args:
      fsas:
        A FsaVec. That is, len(fsas.shape) == 3.

    Returns:
      A single Fsa that is the union of the input fsas.
    '''
    # Request the arc map so the output Fsa can be built from the input's
    # attributes. Previously this was hard-coded to False behind a stale
    # TODO, so every attribute of `fsas` was silently dropped.
    need_arc_map = True
    ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map)

    # The helper receives the source FsaVec together with the arc map,
    # presumably to carry attributes across — verify in k2.utils.
    out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc,
                                                      arc_map)
    return out_fsa
def union(fsas: Fsa) -> Fsa:
    '''Return the union of all the fsas in a FsaVec as a single Fsa.

    Caution:
      Every fsa in ``fsas`` must be non-empty, i.e., contain at least
      two states.

    Args:
      fsas:
        A FsaVec, i.e., ``len(fsas.shape) == 3``.

    Returns:
      A single Fsa that is the union of the input fsas.
    '''
    # The arc map is requested (second argument True) and passed along with
    # the source FsaVec to the helper that constructs the output Fsa.
    ragged_arc, arc_map = _k2.union(fsas.arcs, True)
    return k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map)