def arc_sort(fsa: Fsa) -> Fsa:
    '''Return an FSA whose arcs are sorted within every state.

    Note:
      The sort key is the arc label first and the destination state
      second.

    Caution:
      An input that is already flagged as arc sorted is handed back
      unchanged (the very same object). In every other case a new FSA is
      built and the input is left untouched.

    Args:
      fsa:
        The FSA to sort.

    Returns:
      ``fsa`` itself when it is already arc sorted; otherwise a newly
      created, arc-sorted FSA carrying the attributes of ``fsa``.
    '''
    already_sorted = (fsa.properties & fsa_properties.ARC_SORTED) != 0
    if already_sorted:
        return fsa

    sorted_arcs, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=True)
    result = Fsa(sorted_arcs)

    # Tensor attributes are re-indexed through arc_map so that they line up
    # with the new arc order.
    for attr_name, attr_value in fsa.named_tensor_attr():
        setattr(result, attr_name, index_attr(attr_value, arc_map))

    # Non-tensor attributes are carried over as they are.
    for attr_name, attr_value in fsa.named_non_tensor_attr():
        setattr(result, attr_name, attr_value)

    return result
def arc_sort(fsa: Fsa) -> Fsa:
    '''Return an FSA whose arcs are sorted within every state.

    Note:
      The sort key is the arc label first and the destination state
      second.

    Caution:
      An input that is already flagged as arc sorted is handed back
      unchanged (the very same object). In every other case a new FSA is
      built and the input is left untouched.

    Args:
      fsa:
        The FSA to sort.

    Returns:
      ``fsa`` itself when it is already arc sorted; otherwise a newly
      created, arc-sorted FSA.
    '''
    if (fsa.properties & fsa_properties.ARC_SORTED) != 0:
        # Nothing to do: the property bit says the arcs are in order.
        return fsa

    sorted_arcs, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=True)

    # The helper builds the output FSA from the sorted arcs, the source FSA
    # and the arc map.
    return k2.utils.fsa_from_unary_function_tensor(fsa, sorted_arcs, arc_map)
def arc_sort(fsa: Fsa) -> Fsa:
    '''Return an FSA whose arcs are sorted within every state.

    Note:
      The sort key is the arc label first and the destination state
      second.

    Caution:
      An input that is already known to be arc sorted is handed back
      unchanged (the very same object). In every other case a new FSA is
      built and the input is left untouched.

    Args:
      fsa:
        The FSA to sort.

    Returns:
      ``fsa`` itself when it is already arc sorted; otherwise a newly
      created, arc-sorted FSA carrying the attributes of ``fsa``.
    '''
    # `properties` may be absent on the input, hence the getattr default.
    props = getattr(fsa, 'properties', None)
    if props is not None and is_arc_sorted(props):
        return fsa

    sorted_arcs, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=True)

    # index_select only accepts int64 indices.
    gather_index = arc_map.to(torch.int64)

    result = Fsa.from_ragged_arc(sorted_arcs)

    # Tensor attributes are permuted along dim 0 to follow the new arc order.
    for attr_name, attr_value in fsa.named_tensor_attr():
        setattr(result, attr_name, attr_value.index_select(0, gather_index))

    # Non-tensor attributes are carried over as they are.
    for attr_name, attr_value in fsa.named_non_tensor_attr():
        setattr(result, attr_name, attr_value)

    return result
def arc_sort(
    fsa: Fsa, ret_arc_map: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''Return an FSA whose arcs are sorted within every state.

    Note:
      The sort key is the arc label first and the destination state
      second.

    Caution:
      An input that is already flagged as arc sorted is handed back
      unchanged (the very same object). In every other case a new FSA is
      built and the input is left untouched.

    Args:
      fsa:
        The FSA to sort.
      ret_arc_map:
        When True, additionally return the arc map: a 1-D tensor of dtype
        torch.int32 in which entry ``i`` is the index of the arc in the
        input `fsa` that became the i-th arc of the output FSA.

    Returns:
      With ``ret_arc_map`` False, only the sorted FSA (the input itself if
      it was already arc sorted). With ``ret_arc_map`` True, a tuple
      ``(sorted_fsa, arc_map)``.
    '''
    if (fsa.properties & fsa_properties.ARC_SORTED) != 0:
        if not ret_arc_map:
            return fsa

        # No arc moved, so the requested map is simply the identity.
        identity_map = torch.arange(fsa.num_arcs,
                                    dtype=torch.int32,
                                    device=fsa.device)
        return fsa, identity_map

    sorted_arcs, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=True)
    result = k2.utils.fsa_from_unary_function_tensor(fsa, sorted_arcs,
                                                     arc_map)

    return (result, arc_map) if ret_arc_map else result
def test(self):
    '''Check `k2.utils.fsa_from_unary_function_tensor` with an arc-sort map.

    For every device in ``self.devices`` a three-arc FSA is built whose
    first two arcs are out of label order. After arc sorting, the test
    verifies that:

      - float, int and ragged tensor attributes (and the scores) are
        reordered to follow the sorted arcs;
      - non-tensor attributes are copied unchanged;
      - gradients flow back to the source FSA's ``float_attr`` and
        ``scores`` in the source arc order.
    '''
    for device in self.devices:
        # Fsa.from_str expects one arc per line, with the final state on
        # the last line, so the line breaks in this literal matter.
        s = '''
            0 1 2 10
            0 1 1 20
            1 2 -1 30
            2
        '''
        src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
        src.float_attr = torch.tensor([0.1, 0.2, 0.3],
                                      dtype=torch.float32,
                                      requires_grad=True,
                                      device=device)
        src.int_attr = torch.tensor([1, 2, 3],
                                    dtype=torch.int32,
                                    device=device)
        src.ragged_attr = k2.RaggedTensor([[1, 2, 3], [5, 6],
                                           []]).to(device)
        src.attr1 = 'src'
        src.attr2 = 'fsa'

        ragged_arc, arc_map = _k2.arc_sort(src.arcs, need_arc_map=True)
        dest = k2.utils.fsa_from_unary_function_tensor(
            src, ragged_arc, arc_map)

        # The first two arcs swap places; the third stays where it was.
        assert torch.allclose(
            dest.float_attr,
            torch.tensor([0.2, 0.1, 0.3],
                         dtype=torch.float32,
                         device=device))

        assert torch.all(
            torch.eq(
                dest.scores,
                torch.tensor([20, 10, 30],
                             dtype=torch.float32,
                             device=device)))

        assert torch.all(
            torch.eq(
                dest.int_attr,
                torch.tensor([2, 1, 3], dtype=torch.int32,
                             device=device)))

        expected_ragged_attr = k2.RaggedTensor([[5, 6], [1, 2, 3],
                                                []]).to(device)
        assert dest.ragged_attr == expected_ragged_attr

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        # now for autograd
        scale = torch.tensor([10, 20, 30], device=device)
        (dest.float_attr * scale).sum().backward()
        (dest.scores * scale).sum().backward()

        # The scale is applied in sorted order, so the gradient seen by the
        # source arcs is the scale with its first two entries swapped.
        expected_grad = torch.tensor([20, 10, 30],
                                     dtype=torch.float32,
                                     device=device)

        assert torch.all(torch.eq(src.float_attr.grad, expected_grad))
        assert torch.all(torch.eq(src.scores.grad, expected_grad))
def test_without_negative_1(self):
    '''Check `k2.utils.fsa_from_unary_function_tensor` with an arc-sort map.

    The test runs on the CPU and, when CUDA is available, on GPU 0. A
    three-arc FSA is built whose first two arcs are out of label order.
    After arc sorting, the test verifies that:

      - float, int and ragged tensor attributes (and the scores) are
        reordered to follow the sorted arcs;
      - non-tensor attributes are copied unchanged;
      - gradients flow back to the source FSA's ``float_attr`` and
        ``scores`` in the source arc order.
    '''
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))

    for device in devices:
        # Fsa.from_str expects one arc per line, with the final state on
        # the last line, so the line breaks in this literal matter.
        s = '''
            0 1 2 10
            0 1 1 20
            1 2 -1 30
            2
        '''
        src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
        src.float_attr = torch.tensor([0.1, 0.2, 0.3],
                                      dtype=torch.float32,
                                      requires_grad=True,
                                      device=device)
        src.int_attr = torch.tensor([1, 2, 3],
                                    dtype=torch.int32,
                                    device=device)
        src.ragged_attr = k2.RaggedInt('[[1 2 3] [5 6] []]').to(device)
        src.attr1 = 'src'
        src.attr2 = 'fsa'

        ragged_arc, arc_map = _k2.arc_sort(src.arcs, need_arc_map=True)
        dest = k2.utils.fsa_from_unary_function_tensor(
            src, ragged_arc, arc_map)

        # The first two arcs swap places; the third stays where it was.
        assert torch.allclose(
            dest.float_attr,
            torch.tensor([0.2, 0.1, 0.3],
                         dtype=torch.float32,
                         device=device))

        assert torch.all(
            torch.eq(
                dest.scores,
                torch.tensor([20, 10, 30],
                             dtype=torch.float32,
                             device=device)))

        assert torch.all(
            torch.eq(
                dest.int_attr,
                torch.tensor([2, 1, 3], dtype=torch.int32,
                             device=device)))

        # Ragged values are compared through their string form.
        expected_ragged_attr = k2.RaggedInt('[ [5 6] [1 2 3] []]')
        self.assertEqual(str(dest.ragged_attr), str(expected_ragged_attr))

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        # now for autograd
        scale = torch.tensor([10, 20, 30], device=device)
        (dest.float_attr * scale).sum().backward()
        (dest.scores * scale).sum().backward()

        # The scale is applied in sorted order, so the gradient seen by the
        # source arcs is the scale with its first two entries swapped.
        expected_grad = torch.tensor([20, 10, 30],
                                     dtype=torch.float32,
                                     device=device)

        assert torch.all(torch.eq(src.float_attr.grad, expected_grad))
        assert torch.all(torch.eq(src.scores.grad, expected_grad))