def test_restoration_integrity(self):
        """Encoding a character and decoding the resulting code yields the character."""
        table = CharacterTable()
        expected = 'c'

        code = table.encode(expected)
        actual = table.decode(code)

        self.assertEqual(actual, expected)
    def test_decode_without_blanks_and_repeatitions(self):
        """Decode one label per character of ``self.original_text``, with no blanks.

        The labels contain no blank separators, so adjacent identical labels
        are expected to merge. That is why the expected string has a single
        'l'. (Assumes ``self.original_text`` is 'Hello, world!'; it is set
        outside this view, so confirm.)
        """
        table = CharacterTable()
        decoder = CTCOutputDecoder(table)

        labels = list(map(table.encode, self.original_text))

        self.assertEqual(decoder.decode(labels), 'Helo, world!')
    def setUp(self):
        """Create fixture sequences and a Seq2seqAdapter backed by a CharacterTable."""
        # Two input sequences of 2-element steps (lengths 3 and 1).
        self.seqs_in = [
            [[1, 1], [2, 2], [3, 3]],
            [[4, 4]],
        ]
        # Output label sequences whose lengths match the inputs (3 and 1).
        self.seqs_out = [
            [34, 85, 23],
            [28],
        ]

        table = CharacterTable()
        self.char_table = table

        # Codes of the table's start and sentinel characters, handed to the adapter.
        self.start = table.encode(table.start)
        self.sentinel = table.encode(table.sentinel)

        self.adapter = Seq2seqAdapter(
            self.start,
            self.sentinel,
            num_classes=len(table),
        )
    def test_decode_one_label_sequence(self):
        """A label sequence holding exactly one label decodes to that character."""
        table = CharacterTable()
        decoder = CTCOutputDecoder(table)

        expected = 'c'
        decoded = decoder.decode([table.encode(expected)])

        self.assertEqual(decoded, expected)
    def text_to_codes(self, text):
        """Encode ``text`` and surround every character code with a blank label.

        The blank label is ``len(CharacterTable())``. The result has the form
        ``[blank, c1, blank, c2, ..., cn, blank]``. Empty text gives ``[blank]``.
        """
        table = CharacterTable()
        codes = list(map(table.encode, text))
        blank = len(table)

        seq = [blank]
        for code in codes:
            # Each character code is followed by its own trailing blank.
            seq += [code, blank]
        return seq
示例#6
0
def dummy_source():
    """Return a single-example PreLoadedSource whose target is 'Hello, world'.

    Each character of the input string is repeated, with runs of spaces in
    between. Every character is encoded with a CharacterTable and passed to
    ``to_categorical`` (presumably one-hot encoding; confirm). The result is
    reshaped to ``(1, len(input), -1)``, a batch holding one sequence.
    """
    noisy_input = 'HHHH    eee  lll  lll  ooo  ,,,  www   oooo  rrr   lll  ddd'
    target = 'Hello, world'

    table = CharacterTable()
    categorical = to_categorical(
        [table.encode(ch) for ch in noisy_input],
        num_classes=len(table),
    )

    batch = categorical.reshape(1, len(noisy_input), -1)
    return PreLoadedSource(batch, [target])
    def test_decode_labels(self):
        """Runs of repeated labels separated by runs of blanks decode to the text."""
        table = CharacterTable()
        decoder = CTCOutputDecoder(table)

        blank = len(table)
        num_repeated = 4

        # Layout: a leading blank run, then for each character a run of its
        # label followed by a run of blanks.
        labels = [blank] * num_repeated
        for ch in self.original_text:
            code = table.encode(ch)
            labels += [code] * num_repeated + [blank] * num_repeated

        self.assertEqual(decoder.decode(labels), self.original_text)
    def test_mapping_is_one_to_one(self):
        """Check that CharacterTable's code <-> character mapping is injective both ways.

        Every code in ``range(len(char_table))`` must decode to a distinct
        character. Re-encoding those characters must yield distinct codes.
        Both failure messages include the offending list.
        """
        char_table = CharacterTable()
        decoded_chars = [char_table.decode(code)
                         for code in range(len(char_table))]

        self.assertEqual(
            len(decoded_chars), len(set(decoded_chars)),
            'Got duplicate characters from different codes: {}'.format(
                decoded_chars))

        encoded_chars = [char_table.encode(ch) for ch in decoded_chars]

        # The original format string had no '{}' placeholder, so the codes
        # were silently dropped from the failure message.
        self.assertEqual(
            len(encoded_chars), len(set(encoded_chars)),
            '2 or more characters got mapped to the same code: {}'.format(
                encoded_chars))
    def test_sentinel(self):
        """The table's sentinel character survives an encode/decode round trip."""
        table = CharacterTable()
        expected = table.sentinel

        code = table.encode(expected)
        self.assertEqual(table.decode(code), expected)