Example #1
File: fsa_algo.py Project: OUC-lan/k2
def arc_sort(fsa: Fsa) -> Fsa:
    '''Sort arcs of every state.

    Note:
      Arcs are sorted by labels first, and then by dest states.

    Caution:
      If the input `fsa` is already arc sorted, we return it directly.
      Otherwise, a new sorted fsa is returned.

    Args:
      fsa:
        The input FSA.
    Returns:
      The sorted FSA. It is the same as the input `fsa` if the input
      `fsa` is arc sorted. Otherwise, a new sorted fsa is returned
      and the input `fsa` is NOT modified.
    '''
    if fsa.properties & fsa_properties.ARC_SORTED != 0:
        return fsa

    need_arc_map = True
    ragged_arc, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=need_arc_map)
    out_fsa = Fsa(ragged_arc)
    for name, value in fsa.named_tensor_attr():
        setattr(out_fsa, name, index_attr(value, arc_map))
    for name, value in fsa.named_non_tensor_attr():
        setattr(out_fsa, name, value)

    return out_fsa
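
A minimal usage sketch for the function above (the FSA string is illustrative; `k2.arc_sort` is assumed to be the public wrapper of this function):

import k2

# Two arcs leave state 0, with labels 2 and 1; arc_sort reorders them so
# the label-1 arc comes first. Line format: src_state dest_state label score.
s = '''
0 1 2 0.1
0 1 1 0.2
1 2 -1 0.3
2
'''
fsa = k2.Fsa.from_str(s)
sorted_fsa = k2.arc_sort(fsa)
# Sorting an already-sorted FSA returns it directly, per the docstring.
assert k2.arc_sort(sorted_fsa) is sorted_fsa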
Example #2
File: fsa_algo.py Project: yyht/k2
def arc_sort(fsa: Fsa) -> Fsa:
    '''Sort arcs of every state.

    Note:
      Arcs are sorted by labels first, and then by dest states.

    Caution:
      If the input `fsa` is already arc sorted, we return it directly.
      Otherwise, a new sorted fsa is returned.

    Args:
      fsa:
        The input FSA.
    Returns:
      The sorted FSA. It is the same as the input `fsa` if the input
      `fsa` is arc sorted. Otherwise, a new sorted fsa is returned
      and the input `fsa` is NOT modified.
    '''
    if fsa.properties & fsa_properties.ARC_SORTED != 0:
        return fsa

    need_arc_map = True
    ragged_arc, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=need_arc_map)

    out_fsa = k2.utils.fsa_from_unary_function_tensor(fsa, ragged_arc, arc_map)
    return out_fsa
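
Example #2 is functionally the same as Example #1; `k2.utils.fsa_from_unary_function_tensor` just bundles the attribute propagation that Example #1 writes out by hand. A rough sketch of its contract (illustrative only, not the library's actual implementation):

def fsa_from_unary_function_tensor_sketch(src, ragged_arc, arc_map):
    # Output arc i keeps the tensor attributes (including scores) of
    # input arc arc_map[i]; non-tensor attributes are copied unchanged.
    out = Fsa(ragged_arc)
    for name, value in src.named_tensor_attr():
        setattr(out, name, index_attr(value, arc_map))
    for name, value in src.named_non_tensor_attr():
        setattr(out, name, value)
    return out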
Example #3
def arc_sort(fsa: Fsa) -> Fsa:
    '''Sort arcs of every state.

    Note:
      Arcs are sorted by labels first, and then by dest states.

    Caution:
      If the input ``fsa`` is already arc sorted, we return it directly.
      Otherwise, a new sorted fsa is returned.

    Args:
      fsa:
        The input FSA.
    Returns:
      The sorted FSA. It is the same as the input ``fsa`` if the input
      ``fsa`` is arc sorted. Otherwise, a new sorted fsa is returned
      and the input ``fsa`` is NOT modified.
    '''
    properties = getattr(fsa, 'properties', None)
    if properties is not None and is_arc_sorted(properties):
        return fsa

    need_arc_map = True
    ragged_arc, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=need_arc_map)
    arc_map = arc_map.to(torch.int64)  # required by index_select
    out_fsa = Fsa.from_ragged_arc(ragged_arc)
    for name, value in fsa.named_tensor_attr():
        setattr(out_fsa, name, value.index_select(0, arc_map))
    for name, value in fsa.named_non_tensor_attr():
        setattr(out_fsa, name, value)

    return out_fsa
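
The extra int64 cast in Example #3 is there because `torch.index_select` expects LongTensor indices, while `_k2.arc_sort` returns an int32 arc map. A standalone illustration of that detail:

import torch

values = torch.tensor([0.1, 0.2, 0.3])
arc_map = torch.tensor([1, 0, 2], dtype=torch.int32)  # dtype as returned by _k2.arc_sort
# index_select needs int64 indices, hence the cast in Example #3
reordered = values.index_select(0, arc_map.to(torch.int64))
assert torch.allclose(reordered, torch.tensor([0.2, 0.1, 0.3]))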
Example #4
File: fsa_algo.py Project: entn-at/k2
def arc_sort(
        fsa: Fsa,
        ret_arc_map: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''Sort arcs of every state.

    Note:
      Arcs are sorted by labels first, and then by dest states.

    Caution:
      If the input `fsa` is already arc sorted, we return it directly.
      Otherwise, a new sorted fsa is returned.

    Args:
      fsa:
        The input FSA.
      ret_arc_map:
        True to return an extra arc_map (a 1-D tensor with dtype being
        torch.int32). arc_map[i] is the arc index in the input `fsa` that
        corresponds to the i-th arc in the output Fsa.
    Returns:
      If ret_arc_map is False, return the sorted FSA. It is the same as the
      input `fsa` if the input `fsa` is arc sorted. Otherwise, a new sorted
      fsa is returned and the input `fsa` is NOT modified.
      If ret_arc_map is True, an extra arc map is also returned.
    '''
    if fsa.properties & fsa_properties.ARC_SORTED != 0:
        if ret_arc_map:
            # in this case, arc_map is an identity map
            arc_map = torch.arange(fsa.num_arcs,
                                   dtype=torch.int32,
                                   device=fsa.device)
            return fsa, arc_map
        else:
            return fsa

    need_arc_map = True
    ragged_arc, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=need_arc_map)

    out_fsa = k2.utils.fsa_from_unary_function_tensor(fsa, ragged_arc, arc_map)
    if ret_arc_map:
        return out_fsa, arc_map
    else:
        return out_fsa
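
A short usage sketch for the `ret_arc_map` variant (assuming this signature is the one exported as `k2.arc_sort`; the FSA string is illustrative):

import k2
import torch

s = '''
0 1 2 0.1
0 1 1 0.2
1 2 -1 0.3
2
'''
fsa = k2.Fsa.from_str(s)
sorted_fsa, arc_map = k2.arc_sort(fsa, ret_arc_map=True)
# The label-1 arc (input index 1) now comes first, so arc_map is [1, 0, 2]
assert torch.all(torch.eq(arc_map, torch.tensor([1, 0, 2], dtype=torch.int32)))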
Example #5
    def test(self):
        for device in self.devices:
            s = '''
                0 1 2 10
                0 1 1 20
                1 2 -1 30
                2
            '''
            src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
            src.float_attr = torch.tensor([0.1, 0.2, 0.3],
                                          dtype=torch.float32,
                                          requires_grad=True,
                                          device=device)
            src.int_attr = torch.tensor([1, 2, 3],
                                        dtype=torch.int32,
                                        device=device)
            src.ragged_attr = k2.RaggedTensor([[1, 2, 3], [5, 6],
                                               []]).to(device)
            src.attr1 = 'src'
            src.attr2 = 'fsa'

            # State 0 has two arcs, with labels 2 and 1; sorting by label
            # swaps them, so arc_map is [1, 0, 2]
            ragged_arc, arc_map = _k2.arc_sort(src.arcs, need_arc_map=True)

            dest = k2.utils.fsa_from_unary_function_tensor(
                src, ragged_arc, arc_map)

            assert torch.allclose(
                dest.float_attr,
                torch.tensor([0.2, 0.1, 0.3],
                             dtype=torch.float32,
                             device=device))

            assert torch.all(
                torch.eq(
                    dest.scores,
                    torch.tensor([20, 10, 30],
                                 dtype=torch.float32,
                                 device=device)))

            assert torch.all(
                torch.eq(
                    dest.int_attr,
                    torch.tensor([2, 1, 3], dtype=torch.int32, device=device)))

            expected_ragged_attr = k2.RaggedTensor([[5, 6], [1, 2, 3],
                                                    []]).to(device)
            assert dest.ragged_attr == expected_ragged_attr

            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            # now for autograd
            scale = torch.tensor([10, 20, 30], device=device)
            (dest.float_attr * scale).sum().backward()
            (dest.scores * scale).sum().backward()

            expected_grad = torch.tensor([20, 10, 30],
                                         dtype=torch.float32,
                                         device=device)

            assert torch.all(torch.eq(src.float_attr.grad, expected_grad))

            assert torch.all(torch.eq(src.scores.grad, expected_grad))
Example #6
    def test_without_negative_1(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            s = '''
                0 1 2 10
                0 1 1 20
                1 2 -1 30
                2
            '''
            src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
            src.float_attr = torch.tensor([0.1, 0.2, 0.3],
                                          dtype=torch.float32,
                                          requires_grad=True,
                                          device=device)
            src.int_attr = torch.tensor([1, 2, 3],
                                        dtype=torch.int32,
                                        device=device)
            src.ragged_attr = k2.RaggedInt('[[1 2 3] [5 6] []]').to(device)
            src.attr1 = 'src'
            src.attr2 = 'fsa'

            # As in Example #5, sorting swaps the two arcs of state 0,
            # so arc_map is [1, 0, 2]
            ragged_arc, arc_map = _k2.arc_sort(src.arcs, need_arc_map=True)

            dest = k2.utils.fsa_from_unary_function_tensor(
                src, ragged_arc, arc_map)

            assert torch.allclose(
                dest.float_attr,
                torch.tensor([0.2, 0.1, 0.3],
                             dtype=torch.float32,
                             device=device))

            assert torch.all(
                torch.eq(
                    dest.scores,
                    torch.tensor([20, 10, 30],
                                 dtype=torch.float32,
                                 device=device)))

            assert torch.all(
                torch.eq(
                    dest.int_attr,
                    torch.tensor([2, 1, 3], dtype=torch.int32, device=device)))

            expected_ragged_attr = k2.RaggedInt('[ [5 6] [1 2 3] []]')
            self.assertEqual(str(dest.ragged_attr), str(expected_ragged_attr))

            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            # now for autograd
            scale = torch.tensor([10, 20, 30], device=device)
            (dest.float_attr * scale).sum().backward()
            (dest.scores * scale).sum().backward()

            expected_grad = torch.tensor([20, 10, 30],
                                         dtype=torch.float32,
                                         device=device)

            assert torch.all(torch.eq(src.float_attr.grad, expected_grad))

            assert torch.all(torch.eq(src.scores.grad, expected_grad))