Esempio n. 1
0
 def test_case_1(self):
     """Inverting an empty FSA must yield an empty FSA and empty aux labels."""
     # Input FSA with zero states and zero arcs.
     empty_fsa = k2host.Fsa.create_fsa_with_size(k2host.IntArray2Size(0, 0))
     # Aux labels are passed in even though the FSA has no arcs; presumably
     # the inverter ignores them for an empty input (TODO confirm).
     aux_labels_in = k2host.AuxLabels(
         torch.IntTensor([0, 1, 3, 6, 7]),
         torch.IntTensor([1, 2, 3, 4, 5, 6, 7]))
     inverter = k2host.FstInverter(empty_fsa, aux_labels_in)

     # First ask the inverter how large the outputs need to be.
     out_fsa_size = k2host.IntArray2Size()
     out_aux_size = k2host.IntArray2Size()
     inverter.get_sizes(out_fsa_size, out_aux_size)
     for dim in (out_aux_size.size1, out_aux_size.size2):
         self.assertEqual(dim, 0)

     # Then allocate outputs of exactly those sizes and let it fill them.
     inverted_fsa = k2host.Fsa.create_fsa_with_size(out_fsa_size)
     aux_labels_out = k2host.AuxLabels.create_array_with_size(out_aux_size)
     inverter.get_output(inverted_fsa, aux_labels_out)

     self.assertTrue(k2host.is_empty(inverted_fsa))
     self.assertTrue(aux_labels_out.empty())
Esempio n. 2
0
    def test_case_3(self):
        """Invert a non-top-sorted FSA and compare with the known result."""
        # Text form of the input FSA: 4-column arc lines followed by a lone
        # final-state line (column meaning presumably src/dest/label/weight
        # -- TODO confirm against k2host.str_to_fsa).
        fsa_text = r'''
        0 1 1 0
        0 1 0 0
        0 3 2 0
        1 2 3 0
        1 3 4 0
        2 1 5 0
        2 5 -1 0
        3 1 6 0
        4 5 -1 0
        5
        '''
        in_fsa = k2host.str_to_fsa(fsa_text)

        # Aux labels: the index tensor has 10 entries for the 9 arcs above;
        # some arcs have an empty group (equal consecutive indexes).
        in_aux = k2host.AuxLabels(
            torch.IntTensor([0, 2, 3, 3, 6, 6, 7, 8, 10, 11]),
            torch.IntTensor([1, 2, 3, 5, 6, 7, 8, -1, 9, 10, -1]))
        inverter = k2host.FstInverter(in_fsa, in_aux)

        # Query output sizes, allocate, then fill the outputs.
        fsa_size, aux_size = k2host.IntArray2Size(), k2host.IntArray2Size()
        inverter.get_sizes(fsa_size, aux_size)
        out_fsa = k2host.Fsa.create_fsa_with_size(fsa_size)
        out_aux = k2host.AuxLabels.create_array_with_size(aux_size)
        inverter.get_output(out_fsa, out_aux)

        # (actual, expected) pairs, checked in order: arc indexes, arcs,
        # aux-label indexes, aux-label data.
        checks = [
            (out_fsa.indexes,
             torch.IntTensor([0, 3, 4, 5, 7, 8, 9, 11, 12, 13, 13])),
            (out_fsa.data,
             torch.IntTensor([[0, 1, 1, 0], [0, 3, 3, 0],
                              [0, 7, 0, 0], [1, 3, 2, 0],
                              [2, 3, 10, 0], [3, 4, 5, 0],
                              [3, 7, 0, 0], [4, 5, 6, 0],
                              [5, 6, 7, 0], [6, 3, 8, 0],
                              [6, 9, -1, 0], [7, 2, 9, 0],
                              [8, 9, -1, 0]])),
            (out_aux.indexes,
             torch.IntTensor([0, 0, 0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 7, 8])),
            (out_aux.data,
             torch.IntTensor([2, 1, 6, 4, 3, 5, -1, -1])),
        ]
        for actual, expected in checks:
            self.assertTrue(torch.equal(actual, expected))