コード例 #1
0
    def test_acceptor_from_tensor(self):
        # Arcs of the acceptor as (src_state, dest_state, label, score).
        # The score column is passed through _k2.float_as_int so that every
        # row fits into a single int32 tensor.
        arc_rows = [
            (0, 1, 2, -1.2),
            (0, 2, 10, -2.2),
            (1, 6, -1, -3.2),
            (1, 3, 3, -4.2),
            (2, 6, -1, -5.2),
            (2, 4, 2, -6.2),
            (3, 6, -1, -7.2),
            (5, 0, 1, -8.2),
        ]
        packed_arcs = [[src, dst, label, _k2.float_as_int(score)]
                       for src, dst, label, score in arc_rows]
        fsa = k2.Fsa(torch.tensor(packed_arcs, dtype=torch.int32))

        expected_str = '''
            0 1 2 -1.2
            0 2 10 -2.2
            1 6 -1 -3.2
            1 3 3 -4.2
            2 6 -1 -5.2
            2 4 2 -6.2
            3 6 -1 -7.2
            5 0 1 -8.2
            6
        '''
        actual_str = k2.to_str(fsa)
        assert _remove_leading_spaces(actual_str) == \
            _remove_leading_spaces(expected_str)

        # Drop the last column (the packed score), leaving (src, dest, label).
        arcs = fsa.arcs.values()[:, :-1]
        assert isinstance(arcs, torch.Tensor)
        assert arcs.dtype == torch.int32
        assert arcs.device.type == 'cpu'
        assert arcs.shape == (8, 3), 'there should be 8 arcs'
        first_arc = torch.tensor(arc_rows[0][:3], dtype=torch.int32)
        assert (arcs[0] == first_arc).all()

        # Scores read back from the Fsa must match the original floats.
        expected_scores = torch.tensor([score for *_, score in arc_rows],
                                       dtype=torch.float32)
        assert torch.allclose(fsa.scores, expected_scores)

        # The test expects in-place scaling of fsa.scores to be observable.
        fsa.scores *= -1
        assert torch.allclose(fsa.scores, -expected_scores)
コード例 #2
0
    def test_transducer_from_tensor(self):
        # Arcs as (src_state, dest_state, label, score); the score is packed
        # into the int32 arc row through _k2.float_as_int.
        arc_rows = [
            (0, 1, 2, -1.2),
            (0, 2, 10, -2.2),
            (1, 6, -1, -4.2),
            (1, 3, 3, -3.2),
            (2, 6, -1, -5.2),
            (2, 4, 2, -6.2),
            (3, 6, -1, -7.2),
            (5, 0, 1, -8.2),
        ]
        # One aux label per arc, in arc order.
        aux_label_values = [22, 100, 16, 33, 26, 22, 36, 50]

        for device in self.devices:
            packed_arcs = [[src, dst, label, _k2.float_as_int(score)]
                           for src, dst, label, score in arc_rows]
            fsa_tensor = torch.tensor(packed_arcs,
                                      dtype=torch.int32).to(device)
            aux_labels_tensor = torch.tensor(aux_label_values,
                                             dtype=torch.int32).to(device)
            fsa = k2.Fsa(fsa_tensor, aux_labels_tensor)

            # A freshly built tensor is used as the reference so the check
            # does not depend on the object handed to k2.Fsa.
            aux_labels = fsa.aux_labels
            expected_aux = torch.tensor(aux_label_values,
                                        dtype=torch.int32).to(device)
            assert aux_labels.dtype == torch.int32
            assert aux_labels.device.type == device.type
            assert (aux_labels == expected_aux).all()

            expected_scores = torch.tensor(
                [score for *_, score in arc_rows],
                dtype=torch.float32,
                device=device)
            assert torch.allclose(fsa.scores, expected_scores)

            expected_str = '''
                0 1 2 22 -1.2
                0 2 10 100 -2.2
                1 6 -1 16 -4.2
                1 3 3 33 -3.2
                2 6 -1 26 -5.2
                2 4 2 22 -6.2
                3 6 -1 36 -7.2
                5 0 1 50 -8.2
                6
            '''
            actual_str = k2.to_str(fsa)
            assert _remove_leading_spaces(actual_str) == \
                _remove_leading_spaces(expected_str)
コード例 #3
0
    def test_transducer_from_tensor(self):
        """Build a transducer Fsa from an int32 arc tensor plus aux_labels.

        Runs on CPU and, when available, on the first CUDA device. For each
        device it checks the aux_labels (dtype, device, exact values), the
        scores recovered from the packed arc tensor, and the text form
        produced by k2.to_str.
        """
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            # Rows are (src_state, dest_state, label, score); the score is
            # packed with _k2.float_as_int so the row fits an int32 tensor.
            fsa_tensor = torch.tensor(
                [[0, 1, 2, _k2.float_as_int(-1.2)],
                 [0, 2, 10, _k2.float_as_int(-2.2)],
                 [1, 6, -1, _k2.float_as_int(-4.2)],
                 [1, 3, 3, _k2.float_as_int(-3.2)],
                 [2, 6, -1, _k2.float_as_int(-5.2)],
                 [2, 4, 2, _k2.float_as_int(-6.2)],
                 [3, 6, -1, _k2.float_as_int(-7.2)],
                 [5, 0, 1, _k2.float_as_int(-8.2)]],
                dtype=torch.int32).to(device)
            aux_labels_tensor = torch.tensor([22, 100, 16, 33, 26, 22, 36, 50],
                                             dtype=torch.int32).to(device)
            fsa = k2.Fsa(fsa_tensor, aux_labels_tensor)
            assert fsa.aux_labels.dtype == torch.int32
            assert fsa.aux_labels.device.type == device.type
            # aux_labels are int32 labels and must match exactly.
            # torch.allclose tolerates |a - b| <= atol + rtol * |b|, which
            # for large label values would accept off-by-one labels.
            assert torch.equal(
                fsa.aux_labels,
                torch.tensor([22, 100, 16, 33, 26, 22, 36, 50],
                             dtype=torch.int32).to(device))

            # The packed scores must come back as the original floats.
            assert torch.allclose(
                fsa.scores,
                torch.tensor([-1.2, -2.2, -4.2, -3.2, -5.2, -6.2, -7.2, -8.2],
                             dtype=torch.float32,
                             device=device))

            expected_str = '''
                0 1 2 22 -1.2
                0 2 10 100 -2.2
                1 6 -1 16 -4.2
                1 3 3 33 -3.2
                2 6 -1 26 -5.2
                2 4 2 22 -6.2
                3 6 -1 36 -7.2
                5 0 1 50 -8.2
                6
            '''
            assert _remove_leading_spaces(
                expected_str) == _remove_leading_spaces(k2.to_str(fsa))