Beispiel #1
0
    def test_1d(self):
        """Exercise k2.simple_ragged_index_select with a 1-D int32 ``src``.

        For every device the same ragged index is applied twice: first to a
        contiguous ``src`` and then to a stride-3 view holding identical
        values. Each call must return a tensor of ``src``'s dtype with one
        entry per sublist, equal to ``[2, 10, 0, 0, -1]``; for the strided
        input the result must additionally be contiguous.
        """
        def int32_tensor(data, dev):
            return torch.tensor(data, dtype=torch.int32, device=dev)

        def check_select(src, indexes, num_sublists, dev, dense_check):
            result = k2.simple_ragged_index_select(src, indexes)
            self.assertEqual(result.dtype, src.dtype)
            self.assertEqual(result.numel(), num_sublists)
            if dense_check:
                # A strided input must still produce a packed output.
                self.assertTrue(result.is_contiguous())
                self.assertEqual(result.stride(0), 1)
            want = int32_tensor([2, 10, 0, 0, -1], dev)
            self.assertTrue(torch.allclose(result, want))

        for dev in self.devices:
            # The shape is created on the device of its row_splits, so it
            # does not need an explicit .to(dev).
            shape2 = k2.ragged.create_ragged_shape2(
                int32_tensor([0, 3, 5, 6, 6, 9], dev), None, 9)
            ragged2 = k2.RaggedTensor(
                shape2, int32_tensor([1, 0, 4, 2, 3, 0, 4, 5, 2], dev))

            # contiguous source
            src = int32_tensor([0, 2, 0, 10, 0, -1], dev)
            check_select(src, ragged2, shape2.dim0, dev, dense_check=False)

            # non-contiguous source: repeat each element 3 times, then keep
            # every third entry -> same values, stride 3
            src = src.expand(3, -1).t().flatten()[::3]
            self.assertFalse(src.is_contiguous())
            self.assertEqual(src.stride(0), 3)
            check_select(src, ragged2, shape2.dim0, dev, dense_check=True)
Beispiel #2
0
    def test_with_negative_1(self):
        """k2.ragged.index_and_sum must ignore -1 entries in ``indexes``.

        Checks the forward result (sublists that are empty or contain only
        -1 sum to 0) and the gradient w.r.t. ``src``, where each occurrence
        of index j in sublist i contributes the i-th weight to grad[j].
        """
        for dev in self.devices:
            src = torch.tensor([0, 1, 2, 3],
                               dtype=torch.float32,
                               requires_grad=True,
                               device=dev)
            indexes = k2.RaggedTensor(
                [[1, 2, -1], [0, 3], [-1], [0, 2, 3, 1, 3], []]).to(dev)

            summed = k2.ragged.index_and_sum(src, indexes)
            want = torch.tensor(
                [1 + 2, 0 + 3, 0, 0 + 2 + 3 + 1 + 3, 0]).to(src)
            assert torch.allclose(summed, want)

            # autograd: sublist i is weighted by weights[i]
            weights = torch.tensor([10, 20, 30, 40, 50]).to(dev)
            (summed * weights).sum().backward()
            want_grad = torch.stack([
                weights[1] + weights[3],  # index 0: sublists 1 and 3
                weights[0] + weights[3],  # index 1: sublists 0 and 3
                weights[0] + weights[3],  # index 2: sublists 0 and 3
                weights[1] + weights[3] * 2,  # index 3: once in 1, twice in 3
            ]).to(src.grad)

            assert torch.allclose(src.grad, want_grad)
Beispiel #3
0
 def test_aux_as_ragged(self):
     """Invert a CPU FSA whose aux_labels are a k2.RaggedTensor.

     The aux_labels shape has one sublist per arc (9 arcs, hence 10 row
     splits) holding 11 values in total, with some sublists empty. The
     inverted FSA is only printed; nothing about it is asserted.
     """
     fsa_text = '''
         0 1 1 0
         0 1 0 0
         0 3 2 0
         1 2 3 0
         1 3 4 0
         2 1 5 0
         2 5 -1 0
         3 1 6 0
         4 5 -1 0
         5
     '''
     fsa = k2.Fsa.from_str(fsa_text)
     assert fsa.device.type == 'cpu'

     splits = torch.tensor([0, 2, 3, 3, 6, 6, 7, 8, 10, 11],
                           dtype=torch.int32)
     labels = torch.tensor([1, 2, 3, 5, 6, 7, 8, -1, 9, 10, -1],
                           dtype=torch.int32)
     fsa.aux_labels = k2.RaggedTensor(
         k2.ragged.create_ragged_shape2(splits, None, 11), labels)

     # printing the result also shows its (ragged) aux_labels
     print(k2.invert(fsa))
Beispiel #4
0
def ctc_graph(symbols: Union[List[List[int]], k2.RaggedTensor],
              modified: bool = False,
              device: Optional[Union[torch.device, str]] = "cpu") -> Fsa:
    '''Construct ctc graphs from symbols.

    Note:
      The scores of arcs in the returned FSA are all 0.

    Args:
      symbols:
        It can be one of the following types:

            - A list of list-of-integers, e.g., `[ [1, 2], [1, 2, 3] ]`
            - An instance of :class:`k2.RaggedTensor`.
              Must have `num_axes == 2`.

      modified:
        Option to specify the type of CTC topology: "standard" (False) or
        "modified" (True). The "standard" one makes the blank mandatory
        between a pair of identical symbols; the "modified" one does not.
        Default False.
      device:
        Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a
        torch.device. It is used only when `symbols` is a list, in which
        case the returned FSA is on CPU by default.
        If `symbols` is an instance of :class:`k2.RaggedTensor`, `device`
        is ignored and the returned FSA will be on the same device as
        `symbols`.

    Returns:
        An FsaVec containing the returned ctc graphs, with "Dim0()" the same as
        "len(symbols)"(List[List[int]]) or "dim0"(k2.RaggedTensor). The
        aux_labels produced by the underlying `_k2.ctc_graph` call are
        attached as its `aux_labels` attribute.
    '''
    if not isinstance(symbols, k2.RaggedTensor):
        # Only list input is placed on `device`; an existing RaggedTensor is
        # used as-is, so the graph follows that tensor's device.
        symbols = k2.RaggedTensor(symbols, device=device)

    ragged_arc, aux_labels = _k2.ctc_graph(symbols, modified)
    fsa = Fsa(ragged_arc, aux_labels=aux_labels)
    return fsa
Beispiel #5
0
    def test(self):
        """Smoke-test RnntDecodingStreams over three differently attributed graphs.

        Two graphs carry a tensor ``attr1``, one also carries ``attr2`` and the
        third carries a ragged ``attr3``. Five decoding steps with random
        log-probs are run, then the streams are flushed and the formatted
        output is printed; no values are asserted.
        """
        for dev in self.devices:
            per_graph_streams = []

            ctc5 = k2.ctc_topo(5, device=dev)
            ctc5.attr1 = torch.tensor([1] * ctc5.num_arcs, device=dev)
            per_graph_streams.append(k2.RnntDecodingStream(ctc5))

            trivial = k2.trivial_graph(3, device=dev)
            trivial.attr1 = torch.tensor([2] * trivial.num_arcs, device=dev)
            trivial.attr2 = torch.tensor([22] * trivial.num_arcs, device=dev)
            per_graph_streams.append(k2.RnntDecodingStream(trivial))

            ctc3 = k2.ctc_topo(3, modified=True, device=dev)
            threes = torch.ones((ctc3.num_arcs, 2),
                                dtype=torch.int32,
                                device=dev) * 3
            ctc3.attr3 = k2.RaggedTensor(threes)
            per_graph_streams.append(k2.RnntDecodingStream(ctc3))

            config = k2.RnntDecodingConfig(10, 2, 3.0, 3, 3)
            streams = k2.RnntDecodingStreams(per_graph_streams, config)

            for _ in range(5):
                _, contexts = streams.get_contexts()
                logprobs = torch.randn((contexts.shape[0], 10),
                                       dtype=torch.float32,
                                       device=dev)
                streams.advance(logprobs)

            streams.terminate_and_flush_to_streams()
            print(streams.format_output([3, 4, 5]))
Beispiel #6
0
    def test_sum_per_sublist(self):
        """RaggedTensor.sum() reduces each per-state sublist of arc probabilities.

        The arcs' ragged shape groups arcs by state. After
        ``set_scores_stochastic_``, the probabilities of each non-final
        state's leaving arcs must sum to 1. The final state has no leaving
        arcs, so its sum is 0.
        """
        fsa_text = '''
            0 1 1 0.
            0 1 2 0.
            0 1 3 0.
            1 2 4 0.
            1 2 5 0.
            2 3 -1 0.
            3
        '''
        for dev in self.devices:
            fsa = k2.Fsa.from_str(fsa_text).to(dev)
            fsa.set_scores_stochastic_(torch.randn_like(fsa.scores))
            arc_probs = k2.RaggedTensor(fsa.arcs.shape(), fsa.scores.exp())
            state_sums = arc_probs.sum()

            num_sums = state_sums.numel()
            assert num_sums == fsa.arcs.dim0()

            ones = torch.ones(num_sums - 1, device=dev)
            assert torch.allclose(state_sums[:-1], ones)

            # the final state has no leaving arcs
            assert state_sums[-1].item() == 0
Beispiel #7
0
    def test(self):
        """fsa_from_unary_function_tensor propagates attributes via arc_map.

        Arc-sorting swaps the two arcs leaving state 0, so the tensor, int,
        ragged and score values of ``dest`` are the source values in order
        [1, 0, 2]. String attributes are copied, and gradients are routed
        back through the same permutation.
        """
        def f32(values, dev):
            return torch.tensor(values, dtype=torch.float32, device=dev)

        for dev in self.devices:
            s = '''
                0 1 2 10
                0 1 1 20
                1 2 -1 30
                2
            '''
            src = k2.Fsa.from_str(s).to(dev).requires_grad_(True)
            src.float_attr = torch.tensor([0.1, 0.2, 0.3],
                                          dtype=torch.float32,
                                          requires_grad=True,
                                          device=dev)
            src.int_attr = torch.tensor([1, 2, 3],
                                        dtype=torch.int32,
                                        device=dev)
            src.ragged_attr = k2.RaggedTensor([[1, 2, 3], [5, 6],
                                               []]).to(dev)
            src.attr1 = 'src'
            src.attr2 = 'fsa'

            sorted_arcs, arc_map = _k2.arc_sort(src.arcs, need_arc_map=True)
            dest = k2.utils.fsa_from_unary_function_tensor(
                src, sorted_arcs, arc_map)

            assert torch.allclose(dest.float_attr, f32([0.2, 0.1, 0.3], dev))
            assert torch.all(torch.eq(dest.scores, f32([20, 10, 30], dev)))

            permuted_ints = torch.tensor([2, 1, 3],
                                         dtype=torch.int32,
                                         device=dev)
            assert torch.all(torch.eq(dest.int_attr, permuted_ints))

            permuted_ragged = k2.RaggedTensor([[5, 6], [1, 2, 3],
                                               []]).to(dev)
            assert dest.ragged_attr == permuted_ragged

            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            # autograd: weight both per-arc outputs with the same scale
            weights = torch.tensor([10, 20, 30], device=dev)
            (dest.float_attr * weights).sum().backward()
            (dest.scores * weights).sum().backward()

            # each source arc receives the weight of its position in dest
            want_grad = f32([20, 10, 30], dev)
            assert torch.all(torch.eq(src.float_attr.grad, want_grad))
            assert torch.all(torch.eq(src.scores.grad, want_grad))
    def test_single_fsa(self):
        """remove_epsilon_self_loops on one FSA keeps arcs 0, 1, 4 and 5.

        Besides the resulting topology, this checks that string, int, ragged
        and float attributes and the scores are carried over through that arc
        map. It also checks that gradients flowing back through ``dest`` match
        those of a hand-built reference computed from independent copies of
        the inputs.
        """
        kept_arcs = [0, 1, 4, 5]  # arc map produced for this input
        for dev in self.devices:
            # See https://git.io/JY7r4
            fsa_text = '''
                0 1 0 0.1
                0 2 0 0.2
                0 0 0 0.3
                1 1 0 0.4
                1 2 0 0.5
                2 3 -1 0.6
                3
            '''
            src = k2.Fsa.from_str(fsa_text).to(dev).requires_grad_(True)
            ref_scores = src.scores.detach().clone().requires_grad_(True)

            src.attr1 = "hello"
            src.attr2 = "k2"
            ref_float_attr = torch.tensor([0.1, 0.2, 0.3, 4, 5, 6],
                                          dtype=torch.float32,
                                          requires_grad=True,
                                          device=dev)

            # src gets its own leaf so its gradient accumulates separately
            src.float_attr = ref_float_attr.detach().clone().requires_grad_(
                True)
            src.int_attr = torch.tensor([1, 2, 3, 4, 5, 6],
                                        dtype=torch.int32,
                                        device=dev)
            src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                               [60, 70], [80], [],
                                               [0]]).to(dev)

            dest = k2.remove_epsilon_self_loops(src)

            # See https://git.io/JY7oC
            expected_fsa = k2.Fsa.from_str('''
                0 1 0 0.1
                0 2 0 0.2
                1 2 0 0.5
                2 3 -1 0.6
                3
            ''')
            assert k2.to_str_simple(dest) == k2.to_str_simple(
                expected_fsa), f'{str(dest)}\n{str(expected_fsa)}'

            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            kept_ints = torch.tensor([1, 2, 5, 6],
                                     dtype=torch.int32,
                                     device=dev)
            assert torch.all(torch.eq(dest.int_attr, kept_ints))

            kept_ragged = k2.RaggedTensor([[10, 20], [30, 40, 50], [],
                                           [0]]).to(dev)
            assert dest.ragged_attr == kept_ragged

            expected_float_attr = torch.empty_like(dest.float_attr)
            for out_idx, in_idx in enumerate(kept_arcs):
                expected_float_attr[out_idx] = ref_float_attr[in_idx]
            assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

            expected_scores = torch.empty_like(dest.scores)
            for out_idx, in_idx in enumerate(kept_arcs):
                expected_scores[out_idx] = ref_scores[in_idx]
            assert torch.all(torch.eq(dest.scores, expected_scores))

            weights = torch.tensor([10, 20, 30, 40]).to(ref_float_attr)

            (dest.float_attr * weights).sum().backward()
            (expected_float_attr * weights).sum().backward()
            assert torch.all(
                torch.eq(src.float_attr.grad, ref_float_attr.grad))

            (dest.scores * weights).sum().backward()
            (expected_scores * weights).sum().backward()
            assert torch.all(torch.eq(src.scores.grad, ref_scores.grad))
    def test_fsa_vec(self):
        """remove_epsilon_self_loops on an FsaVec of two identical topologies.

        Each FSA keeps arcs 0, 1, 4 and 5, i.e. arc map
        [0, 1, 4, 5, 6, 7, 10, 11] over the concatenated arcs. The test checks
        the topology, attribute propagation (string, int, ragged, float), and
        that gradients reaching each input FSA through ``dest`` match those of
        a hand-built reference.

        The float attributes attached to the FSAs are detached clones of the
        reference leaves. The two gradient paths therefore accumulate into
        different tensors, so the gradient comparison is meaningful rather
        than comparing a tensor's ``.grad`` with itself.
        """
        for device in self.devices:
            # See https://git.io/JY7r4
            s = '''
                0 1 0 0.1
                0 2 0 0.2
                0 0 0 0.3
                1 1 0 0.4
                1 2 0 0.5
                2 3 -1 0.6
                3
            '''
            fsa1 = k2.Fsa.from_str(s).to(device).requires_grad_(True)
            scores_copy1 = fsa1.scores.detach().clone().requires_grad_(True)
            fsa1.attr1 = "hello"
            float_attr1 = torch.tensor([0.1, 0.2, 0.3, 4, 5, 6],
                                       dtype=torch.float32,
                                       requires_grad=True,
                                       device=device)
            # Independent leaf: assigning float_attr1 itself would make
            # fsa1.float_attr.grad and float_attr1.grad the same tensor.
            fsa1.float_attr = float_attr1.detach().clone().requires_grad_(True)
            fsa1.int_attr = torch.tensor([1, 2, 3, 4, 5, 6],
                                         dtype=torch.int32,
                                         device=device)
            fsa1.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40,
                                                           50], [60, 70], [80],
                                                [], [0]]).to(device)

            fsa2 = k2.Fsa.from_str(s).to(device).requires_grad_(True)
            scores_copy2 = fsa2.scores.detach().clone().requires_grad_(True)
            fsa2.attr2 = "k2"
            float_attr2 = torch.tensor([1, 2, 3, 40, 50, 60],
                                       dtype=torch.float32,
                                       requires_grad=True,
                                       device=device)
            fsa2.float_attr = float_attr2.detach().clone().requires_grad_(True)
            fsa2.int_attr = torch.tensor([10, 20, 30, 4, 5, 6],
                                         dtype=torch.int32,
                                         device=device)
            fsa2.ragged_attr = k2.RaggedTensor([[100, 200], [300, 400, 500],
                                                [600, 700], [800], [22],
                                                [33, 55]]).to(device)

            src = k2.create_fsa_vec([fsa1, fsa2])

            dest = k2.remove_epsilon_self_loops(src)
            # arc map is [0, 1, 4, 5, 6, 7, 10, 11]

            # See https://git.io/JY7oC
            expected_fsa = k2.Fsa.from_str('''
                0 1 0 0.1
                0 2 0 0.2
                1 2 0 0.5
                2 3 -1 0.6
                3
            ''')
            assert k2.to_str_simple(dest[0]) == k2.to_str_simple(expected_fsa)
            assert k2.to_str_simple(dest[1]) == k2.to_str_simple(expected_fsa)

            assert dest.attr1 == fsa1.attr1
            assert dest.attr2 == fsa2.attr2

            expected_int_attr = torch.tensor([1, 2, 5, 6, 10, 20, 5, 6],
                                             dtype=torch.int32,
                                             device=device)
            assert torch.all(torch.eq(dest.int_attr, expected_int_attr))

            expected_ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50], [],
                                                    [0], [100, 200],
                                                    [300, 400, 500], [22],
                                                    [33, 55]]).to(device)
            assert dest.ragged_attr == expected_ragged_attr

            # The reference is built from the independent leaves
            # float_attr1/float_attr2, not from the FSAs' attributes.
            expected_float_attr = torch.empty_like(dest.float_attr)
            expected_float_attr[0] = float_attr1[0]
            expected_float_attr[1] = float_attr1[1]
            expected_float_attr[2] = float_attr1[4]
            expected_float_attr[3] = float_attr1[5]
            expected_float_attr[4] = float_attr2[0]
            expected_float_attr[5] = float_attr2[1]
            expected_float_attr[6] = float_attr2[4]
            expected_float_attr[7] = float_attr2[5]

            assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

            expected_scores = torch.empty_like(dest.scores)
            expected_scores[0] = scores_copy1[0]
            expected_scores[1] = scores_copy1[1]
            expected_scores[2] = scores_copy1[4]
            expected_scores[3] = scores_copy1[5]
            expected_scores[4] = scores_copy2[0]
            expected_scores[5] = scores_copy2[1]
            expected_scores[6] = scores_copy2[4]
            expected_scores[7] = scores_copy2[5]

            assert torch.all(torch.eq(dest.scores, expected_scores))

            scale = torch.tensor([10, 20, 30, 40, 50, 60, 70,
                                  80]).to(dest.float_attr)

            (dest.float_attr * scale).sum().backward()
            (expected_float_attr * scale).sum().backward()

            assert torch.all(torch.eq(fsa1.float_attr.grad, float_attr1.grad))
            assert torch.all(torch.eq(fsa2.float_attr.grad, float_attr2.grad))

            (dest.scores * scale).sum().backward()
            (expected_scores * scale).sum().backward()

            assert torch.all(torch.eq(fsa1.scores.grad, scores_copy1.grad))
            assert torch.all(torch.eq(fsa2.scores.grad, scores_copy2.grad))
Beispiel #10
0
import k2
# Build a two-arc FSA whose aux_labels are a RaggedTensor, invert it with
# k2.invert, and draw both the original and the inverted FSA to .svg files.
s = '''
0 1 2 0.1
1 2 -1 0.2
2
'''
fsa = k2.Fsa.from_str(s)
# One sublist per arc: the first arc gets [10, 20], the final arc gets [-1].
fsa.aux_labels = k2.RaggedTensor('[ [10 20] [-1] ]')
inverted_fsa = k2.invert(fsa)
fsa.draw('before_invert_aux.svg',
         title='before invert with ragged tensors as aux_labels')
inverted_fsa.draw('after_invert_aux.svg', title='after invert')
Beispiel #11
0
    def test(self):
        """k2.union of three FSAs, including tensor and ragged attributes.

        The result has a new start state 0 with one arc (label 0, score 0) to
        each input's start state, followed by the input arcs with their
        states renumbered. On the new arcs, the tensor attribute is 0 and the
        ragged attribute is an empty sublist.
        """
        s0 = '''
            0 1 1 0.1
            0 2 2 0.2
            1 2 3 0.3
            1 3 -1 0.4
            2 3 -1 0.5
            2 1 5 0.55
            3
        '''
        s1 = '''
            0 1 -1 0.6
            1
        '''
        s2 = '''
            0 1 6 0.7
            1 0 7 0.8
            1 0 8 0.9
            1 2 -1 1.0
            2
        '''
        for device in self.devices:
            # Put each FSA on the target device before attaching attributes,
            # so arcs and per-arc attributes never live on different devices.
            fsa0 = k2.Fsa.from_str(s0).to(device)
            fsa1 = k2.Fsa.from_str(s1).to(device)
            fsa2 = k2.Fsa.from_str(s2).to(device)

            fsa0.tensor_attr = torch.tensor([1, 2, 3, 4, 5, 6],
                                            dtype=torch.int32,
                                            device=device)
            fsa0.ragged_tensor_attr = k2.RaggedTensor(
                fsa0.tensor_attr.unsqueeze(-1))

            fsa1.tensor_attr = torch.tensor([7],
                                            dtype=torch.int32,
                                            device=device)

            fsa1.ragged_tensor_attr = k2.RaggedTensor(
                fsa1.tensor_attr.unsqueeze(-1))

            fsa2.tensor_attr = torch.tensor([8, 9, 10, 11],
                                            dtype=torch.int32,
                                            device=device)

            fsa2.ragged_tensor_attr = k2.RaggedTensor(
                fsa2.tensor_attr.unsqueeze(-1))

            fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2])

            fsa = k2.union(fsa_vec)

            # The three new arcs leaving the new start state get 0.
            expected_tensor_attr = torch.tensor(
                [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
                 11]).to(fsa.tensor_attr)
            assert torch.all(torch.eq(fsa.tensor_attr, expected_tensor_attr))

            # ...and empty sublists for the ragged attribute.
            expected_ragged_tensor_attr = k2.RaggedTensor(
                expected_tensor_attr.unsqueeze(-1)).remove_values_eq(0)
            assert str(expected_ragged_tensor_attr) == str(
                fsa.ragged_tensor_attr)

            # Arcs are int32 [src_state, dest_state, label]: compare exactly.
            assert torch.equal(
                fsa.arcs.values()[:, :3],
                torch.tensor([
                    [0, 1, 0],  # fsa 0
                    [0, 4, 0],  # fsa 1
                    [0, 5, 0],  # fsa 2
                    # now for fsa0
                    [1, 2, 1],
                    [1, 3, 2],
                    [2, 3, 3],
                    [2, 7, -1],
                    [3, 7, -1],
                    [3, 2, 5],
                    # fsa1
                    [4, 7, -1],
                    # fsa2
                    [5, 6, 6],
                    [6, 5, 7],
                    [6, 5, 8],
                    [6, 7, -1]
                ], dtype=torch.int32, device=device))
            assert torch.allclose(
                fsa.scores,
                torch.tensor([
                    0., 0., 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.7, 0.8,
                    0.9, 1.0
                ]).to(device))
Beispiel #12
0
    def test_autograd_remove_epsilon_and_add_self_loops(self):
        """Autograd through remove_epsilon_and_add_self_loops (CUDA only).

        Returns early unless CUDA is available and k2 was built with CUDA.
        Runs on cuda:0 and, when a second GPU exists, on cuda:1 with cuda:1
        made the current device. Without self-loops the arc map would be
        [[1] [0 2] [2]]; with them it is [[] [1] [0 2] [] [2]], so the added
        arcs get empty/zero attributes.

        ``torch.cuda.set_device`` is process-wide, so the previously current
        CUDA device is restored before returning to avoid leaking state into
        other tests.
        """
        if not torch.cuda.is_available():
            return

        if not k2.with_cuda:
            return

        previous_cuda_device = torch.cuda.current_device()
        s = '''
            0 1 0 0.1
            0 1 1 0.2
            1 2 -1 0.3
            2
        '''
        try:
            devices = [torch.device('cuda', 0)]
            if torch.cuda.device_count() > 1:
                # NOTE(review): presumably makes cuda:1 current so the cuda:0
                # run checks that k2 follows each tensor's own device.
                torch.cuda.set_device(1)
                devices.append(torch.device('cuda', 1))

            for device in devices:
                src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
                scores_copy = src.scores.detach().clone().requires_grad_(True)

                src.attr1 = "hello"
                src.attr2 = "k2"
                float_attr = torch.tensor([0.1, 0.2, 0.3],
                                          dtype=torch.float32,
                                          requires_grad=True,
                                          device=device)

                src.float_attr = float_attr.detach().clone().requires_grad_(
                    True)
                src.int_attr = torch.tensor([1, 2, 3],
                                            dtype=torch.int32,
                                            device=device)
                src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                                   [60, 70]]).to(device)

                dest = k2.remove_epsilon_and_add_self_loops(src)
                # without add_self_loops, the arc map is [[1] [0 2] [2]]
                # with add_self_loops, the arc map is [[] [1] [0 2] [] [2]]

                assert dest.attr1 == src.attr1
                assert dest.attr2 == src.attr2

                expected_int_attr = k2.RaggedTensor([[], [2], [1, 3], [],
                                                     [3]]).to(device)
                assert dest.int_attr == expected_int_attr

                expected_ragged_attr = k2.RaggedTensor([[], [30, 40, 50],
                                                        [10, 20, 60, 70], [],
                                                        [60, 70]]).to(device)
                assert dest.ragged_attr == expected_ragged_attr

                expected_float_attr = torch.empty_like(dest.float_attr)
                expected_float_attr[0] = 0
                expected_float_attr[1] = float_attr[1]
                expected_float_attr[2] = float_attr[0] + float_attr[2]
                expected_float_attr[3] = 0
                expected_float_attr[4] = float_attr[2]

                assert torch.all(
                    torch.eq(dest.float_attr, expected_float_attr))

                expected_scores = torch.empty_like(dest.scores)
                expected_scores[0] = 0
                expected_scores[1] = scores_copy[1]
                expected_scores[2] = scores_copy[0] + scores_copy[2]
                expected_scores[3] = 0
                expected_scores[4] = scores_copy[2]

                assert torch.all(torch.eq(dest.scores, expected_scores))

                scale = torch.tensor([10, 20, 30, 40, 50]).to(float_attr)

                (dest.float_attr * scale).sum().backward()
                (expected_float_attr * scale).sum().backward()
                assert torch.all(
                    torch.eq(src.float_attr.grad, float_attr.grad))

                (dest.scores * scale).sum().backward()
                (expected_scores * scale).sum().backward()
                assert torch.all(torch.eq(src.scores.grad, scores_copy.grad))
        finally:
            torch.cuda.set_device(previous_cuda_device)
    def test_without_empty_list(self):
        """fsa_from_unary_function_ragged with an arc map lacking empty sublists.

        _k2.remove_epsilon maps the 3 input arcs to 3 output arcs via
        arc_map [[1], [0, 2], [2]]. Per the expected values below, int
        attributes become ragged lists, ragged attributes are concatenated,
        and float attributes and scores are summed over each sublist.
        Gradients must follow the same map.
        """
        for dev in self.devices:
            fsa_text = '''
                0 1 0 0
                0 1 1 0
                1 2 -1 0
                2
            '''
            scores = torch.tensor([1, 2, 3],
                                  dtype=torch.float32,
                                  device=dev,
                                  requires_grad=True)
            ref_scores = scores.detach().clone().requires_grad_(True)
            src = k2.Fsa.from_str(fsa_text).to(dev)
            src.scores = scores
            # see https://git.io/Jufpl
            src.attr1 = "hello"
            src.attr2 = "k2"
            ref_float_attr = torch.tensor([0.1, 0.2, 0.3],
                                          dtype=torch.float32,
                                          requires_grad=True,
                                          device=dev)

            # independent leaf so src and the reference get separate grads
            src.float_attr = ref_float_attr.detach().clone().requires_grad_(
                True)
            src.int_attr = torch.tensor([1, 2, 3],
                                        dtype=torch.int32,
                                        device=dev)
            src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                               [60, 70]]).to(dev)

            out_arcs, arc_map = _k2.remove_epsilon(src.arcs, src.properties)
            # see https://git.io/Jufpe
            dest = k2.utils.fsa_from_unary_function_ragged(
                src, out_arcs, arc_map)
            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            assert arc_map == k2.RaggedTensor([[1], [0, 2], [2]]).to(dev)
            assert dest.int_attr == k2.RaggedTensor([[2], [1, 3],
                                                     [3]]).to(dev)
            assert dest.ragged_attr == k2.RaggedTensor(
                [[30, 40, 50], [10, 20, 60, 70], [60, 70]]).to(dev)

            expected_float_attr = torch.empty_like(dest.float_attr)
            expected_float_attr[0] = ref_float_attr[1]
            expected_float_attr[1] = ref_float_attr[0] + ref_float_attr[2]
            expected_float_attr[2] = ref_float_attr[2]
            assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

            expected_scores = torch.empty_like(dest.scores)
            expected_scores[0] = ref_scores[1]
            expected_scores[1] = ref_scores[0] + ref_scores[2]
            expected_scores[2] = ref_scores[2]
            assert torch.all(torch.eq(dest.scores, expected_scores))

            weights = torch.tensor([10, 20, 30]).to(ref_float_attr)

            (dest.float_attr * weights).sum().backward()
            (expected_float_attr * weights).sum().backward()
            assert torch.all(
                torch.eq(src.float_attr.grad, ref_float_attr.grad))

            (dest.scores * weights).sum().backward()
            (expected_scores * weights).sum().backward()
            assert torch.all(torch.eq(scores.grad, ref_scores.grad))
Beispiel #14
0
    def test(self):
        """k2.arc_sort with ret_arc_map=True permutes arcs and attributes.

        Sorting swaps the two arcs leaving state 0 (labels 2 and 1), giving
        arc_map [1, 0, 2]. Every per-arc attribute and the scores must follow
        that permutation, and string attributes are copied. Gradients through
        ``dest`` must equal those of a reference built from independent
        copies of the inputs.
        """
        s = '''
            0 1 2 0.1
            0 1 1 0.2
            1 2 -1 0.3
            2
        '''
        perm = [1, 0, 2]
        for dev in self.devices:
            src = k2.Fsa.from_str(s).to(dev)
            src.requires_grad_(True)

            ref_scores = src.scores.detach().clone().requires_grad_(True)

            src.attr1 = "hello"
            src.attr2 = "k2"
            ref_float_attr = torch.tensor([0.1, 0.2, 0.3],
                                          dtype=torch.float32,
                                          requires_grad=True,
                                          device=dev)
            src.float_attr = ref_float_attr.detach().clone().requires_grad_(
                True)
            src.int_attr = torch.tensor([1, 2, 3],
                                        dtype=torch.int32,
                                        device=dev)
            src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                               [60, 70]]).to(dev)

            dest, arc_map = k2.arc_sort(src, ret_arc_map=True)

            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            want_map = torch.tensor(perm, dtype=torch.int32, device=dev)
            assert torch.all(torch.eq(arc_map, want_map))

            want_str = '\n'.join(
                ['0 1 1 0.2', '0 1 2 0.1', '1 2 -1 0.3', '2'])
            assert k2.to_str_simple(dest).strip() == want_str

            want_ints = torch.tensor([2, 1, 3],
                                     dtype=torch.int32,
                                     device=dev)
            assert torch.all(torch.eq(dest.int_attr, want_ints))

            want_ragged = k2.RaggedTensor([[30, 40, 50], [10, 20],
                                           [60, 70]]).to(dev)
            assert dest.ragged_attr == want_ragged

            expected_float_attr = torch.empty_like(dest.float_attr)
            for out_idx, in_idx in enumerate(perm):
                expected_float_attr[out_idx] = ref_float_attr[in_idx]
            assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

            expected_scores = torch.empty_like(dest.scores)
            for out_idx, in_idx in enumerate(perm):
                expected_scores[out_idx] = ref_scores[in_idx]
            assert torch.all(torch.eq(dest.scores, expected_scores))

            weights = torch.tensor([10, 20, 30]).to(ref_float_attr)

            (dest.float_attr * weights).sum().backward()
            (expected_float_attr * weights).sum().backward()
            assert torch.all(
                torch.eq(src.float_attr.grad, ref_float_attr.grad))

            (dest.scores * weights).sum().backward()
            (expected_scores * weights).sum().backward()
            assert torch.all(torch.eq(src.scores.grad, ref_scores.grad))