def test_restoration_integrity(self):
        """Encoding a character and decoding the result gives it back."""
        table = CharacterTable()
        original = 'c'

        code = table.encode(original)
        restored = table.decode(code)

        self.assertEqual(restored, original)
    def test_consistency(self):
        """Two independently built tables encode a character identically."""
        sample = 'a'

        # A fresh table is created for each encoding on purpose.
        codes = [CharacterTable().encode(sample) for _ in range(2)]

        self.assertEqual(codes[0], codes[1])
예제 #3
0
    def predict(self, input_sequence):
        """Return the text predicted for ``input_sequence``.

        The inference model's labels are decoded one by one through a
        ``CharacterTable``, joined into a single string, and stripped of
        leading and trailing whitespace.
        """
        # todo: refactor
        predicted_labels = self._inference_model.predict(input_sequence)
        table = CharacterTable()

        text = ''.join(table.decode(label) for label in predicted_labels)
        return text.strip()
    def test_decode_without_blanks_and_repeatitions(self):
        """Decode the codes of ``self.original_text`` with no blanks between.

        The expected output is 'Helo, world!'.
        NOTE(review): presumably ``self.original_text`` is 'Hello, world!' and
        the decoder merges the adjacent 'l' labels -- confirm against the
        fixture and CTCOutputDecoder.
        """
        table = CharacterTable()
        ctc_decoder = CTCOutputDecoder(table)

        codes = [table.encode(symbol) for symbol in self.original_text]
        decoded = ctc_decoder.decode(codes)

        self.assertEqual(decoded, 'Helo, world!')
    def test_decode_one_label_sequence(self):
        """A sequence holding one label decodes to that label's character."""
        table = CharacterTable()
        ctc_decoder = CTCOutputDecoder(table)

        expected = 'c'
        decoded = ctc_decoder.decode([table.encode(expected)])

        self.assertEqual(decoded, expected)
    def text_to_codes(self, text):
        """Encode ``text`` and interleave the codes with a blank code.

        The blank code is ``len(CharacterTable())``. The returned list starts
        with a blank and has one blank after every character code, e.g.
        ``[blank, c0, blank, c1, blank]`` for a two-character text.
        """
        table = CharacterTable()

        char_codes = [table.encode(symbol) for symbol in text]
        blank_code = len(table)

        interleaved = [blank_code]
        for char_code in char_codes:
            interleaved += [char_code, blank_code]

        return interleaved
예제 #7
0
def dummy_source():
    """Build a ``PreLoadedSource`` with one fixed input/target pair.

    The input string is encoded character by character, one-hot encoded with
    ``to_categorical`` (one class per character-table entry) and reshaped to
    ``(1, len(input), -1)``. The target is the list ``['Hello, world']``.
    """
    input_text = 'HHHH    eee  lll  lll  ooo  ,,,  www   oooo  rrr   lll  ddd'
    target_text = 'Hello, world'

    table = CharacterTable()
    char_codes = [table.encode(symbol) for symbol in input_text]

    one_hot = to_categorical(char_codes, num_classes=len(table))
    batch = one_hot.reshape(1, len(input_text), -1)

    return PreLoadedSource(batch, [target_text])
    def test_decode_labels(self):
        """Decode a sequence where every label and blank is repeated 4 times.

        The sequence starts with a run of blanks; each character of
        ``self.original_text`` then contributes a run of its code followed by
        a run of blanks. The decoder must return ``self.original_text``.
        """
        table = CharacterTable()
        ctc_decoder = CTCOutputDecoder(table)

        blank_code = len(table)
        run_length = 4
        blank_run = [blank_code] * run_length

        sequence = list(blank_run)
        for symbol in self.original_text:
            sequence += [table.encode(symbol)] * run_length
            sequence += blank_run

        self.assertEqual(ctc_decoder.decode(sequence), self.original_text)
    def test_decode_sequence_of_blanks(self):
        """Fifty blank labels (code ``len(char_table)``) decode to ''."""
        table = CharacterTable()
        ctc_decoder = CTCOutputDecoder(table)

        only_blanks = [len(table) for _ in range(50)]
        decoded = ctc_decoder.decode(only_blanks)

        self.assertEqual(decoded, '')
    def setUp(self):
        """Prepare sample sequences and a ``Seq2seqAdapter`` for the tests.

        Stores two input sequences, two output sequences, the character
        table, the encoded start and sentinel symbols, and an adapter built
        from those codes and the table size.
        """
        self.seqs_in = [[[1, 1], [2, 2], [3, 3]], [[4, 4]]]
        self.seqs_out = [[34, 85, 23], [28]]

        self.char_table = CharacterTable()
        self.start = self.char_table.encode(self.char_table.start)
        self.sentinel = self.char_table.encode(self.char_table.sentinel)

        self.adapter = Seq2seqAdapter(
            self.start,
            self.sentinel,
            num_classes=len(self.char_table),
        )
    def create_distribution(self, codes):
        """Return a one-hot matrix with one row per entry of ``codes``.

        The matrix has shape ``(len(codes), len(CharacterTable()) + 1)``; row
        ``i`` is all zeros except for a 1.0 in column ``codes[i]``.
        """
        table = CharacterTable()

        shape = (len(codes), len(table) + 1)
        one_hot = np.zeros(shape)
        for row, column in enumerate(codes):
            one_hot[row, column] = 1.0

        return one_hot
    def test_mapping_is_one_to_one(self):
        """Check that codes and characters map to each other without collisions.

        Every code in ``range(len(char_table))`` is decoded and the resulting
        characters must be pairwise distinct. Those characters are then
        encoded again and the resulting codes must be pairwise distinct too.
        """
        char_table = CharacterTable()
        decoded_chars = [char_table.decode(code)
                         for code in range(len(char_table))]

        self.assertEqual(
            len(decoded_chars), len(set(decoded_chars)),
            'Got duplicate characters from different codes: {}'.format(
                decoded_chars))

        encoded_chars = [char_table.encode(ch) for ch in decoded_chars]

        # The '{}' placeholder puts the offending codes into the failure
        # message so a collision can be identified from the test output.
        self.assertEqual(
            len(encoded_chars), len(set(encoded_chars)),
            '2 or more characters got mapped to the same code: {}'.format(
                encoded_chars))
    def setUp(self):
        """Create a two-word ``WordDictionary`` with fixed transition weights.

        Also stores a character table and the codes 0 and 1 used for
        "hello" and "world" respectively.
        """
        word_transitions = dict([
            (("hello", "world"), 0.75),
            (("hello", "hello"), 0.25),
            (("world", "hello"), 0.3),
            (("world", "world"), 0.7),
        ])

        self.char_table = CharacterTable()
        self.hello_code, self.world_code = 0, 1
        self.dictionary = WordDictionary(["hello", "world"], word_transitions)
예제 #14
0
    def decode_next(self, prev_y, prev_state):
        """Return scripted probabilities for the next step and a dummy state.

        Both returned lists have one entry per character-table code, filled
        with 0.001. The probability list additionally gets 1.0 at the code
        taken from ``self.result[self.counter]``, or at code 27 once the
        scripted result is exhausted. ``self.counter`` is advanced by one.
        ``prev_y`` and ``prev_state`` are not used.
        """
        table_size = len(CharacterTable())
        epsilon = 0.001
        probabilities = [epsilon] * table_size
        state = [epsilon] * table_size

        # NOTE(review): 27 is hard-coded; presumably a terminating code --
        # confirm against CharacterTable.
        exhausted = self.counter >= len(self.result)
        code = 27 if exhausted else self.result[self.counter]

        probabilities[code] = 1.0
        self.counter += 1

        return probabilities, state
예제 #15
0
 def decode_next(self, prev_y, prev_state):
     """Return two separate lists of 0.001, one entry per table code.

     The first list is the next-step probabilities and the second is the
     next state. ``prev_y`` and ``prev_state`` are not used.
     """
     table_size = len(CharacterTable())
     epsilon = 0.001
     probabilities = [epsilon for _ in range(table_size)]
     state = [epsilon for _ in range(table_size)]
     return probabilities, state
 def test_decode_sequence_with_unknown_characters(self):
     """Labels 165, 123568 and 123586 are expected to decode to '???'."""
     table = CharacterTable()
     ctc_decoder = CTCOutputDecoder(table)
     unknown_labels = [165, 123568, 123586]
     decoded = ctc_decoder.decode(unknown_labels)
     self.assertEqual(decoded, '???')
 def test_decode_empty_sequence(self):
     """Decoding an empty list of labels yields an empty string."""
     ctc_decoder = CTCOutputDecoder(CharacterTable())
     decoded = ctc_decoder.decode([])
     self.assertEqual(decoded, '')
    def test_sentinel(self):
        """The table's sentinel symbol survives an encode/decode round trip."""
        table = CharacterTable()
        expected = table.sentinel

        code = table.encode(expected)

        self.assertEqual(table.decode(code), expected)
예제 #19
0
 def get_initial_state(self):
     """Return a list of 0.001 with one entry per character-table code."""
     table_size = len(CharacterTable())
     epsilon = 0.001
     return [epsilon for _ in range(table_size)]
 def test_decode_out_of_alphabet(self):
     """Decoding the code ``len(char_table)`` is expected to give '?'."""
     table = CharacterTable()
     out_of_range_code = len(table)
     decoded = table.decode(out_of_range_code)
     self.assertEqual(decoded, '?')