def test_restoration_integrity(self):
    """Encoding a character and decoding the result yields the character."""
    table = CharacterTable()
    expected = 'c'
    code = table.encode(expected)
    # Round trip goes through the same table instance.
    self.assertEqual(table.decode(code), expected)
def test_consistency(self):
    """Two independently built tables encode the same character alike."""
    sample = 'a'
    # Each encoding deliberately comes from its own fresh table instance.
    codes = [CharacterTable().encode(sample) for _ in range(2)]
    self.assertEqual(codes[0], codes[1])
def predict(self, input_sequence):
    """Return the text predicted by the inference model for the input.

    The labels returned by ``self._inference_model.predict`` are decoded
    one by one with a ``CharacterTable``, concatenated, and stripped of
    leading/trailing whitespace.
    """
    # todo: refactor
    predicted_labels = self._inference_model.predict(input_sequence)
    table = CharacterTable()
    characters = (table.decode(code) for code in predicted_labels)
    text = ''.join(characters)
    return text.strip()
def test_decode_without_blanks_and_repeatitions(self):
    """Decode labels built directly from ``self.original_text``.

    No blank labels are inserted between characters, and the decoder is
    expected to return ``'Helo, world!'``.
    """
    table = CharacterTable()
    ctc_decoder = CTCOutputDecoder(table)
    encoded_text = list(map(table.encode, self.original_text))
    decoded_text = ctc_decoder.decode(encoded_text)
    self.assertEqual(decoded_text, 'Helo, world!')
def test_decode_one_label_sequence(self):
    """A sequence holding a single label decodes to that one character."""
    table = CharacterTable()
    ctc_decoder = CTCOutputDecoder(table)
    expected = 'c'
    single_label = [table.encode(expected)]
    decoded = ctc_decoder.decode(single_label)
    self.assertEqual(decoded, expected)
def text_to_codes(self, text):
    """Encode ``text`` into codes separated and surrounded by blanks.

    The blank code is ``len(CharacterTable())``. The result starts with a
    blank and every character code is followed by one more blank, e.g.
    ``[blank, c1, blank, c2, blank]``.
    """
    table = CharacterTable()
    encoded = list(map(table.encode, text))
    blank_code = len(table)

    interleaved = [blank_code]
    for char_code in encoded:
        interleaved += [char_code, blank_code]
    return interleaved
def dummy_source():
    """Build a ``PreLoadedSource`` holding one hand-made example.

    The input is a one-hot encoding (via ``to_categorical``) of a string
    with repeated characters, reshaped to a batch of size 1; the target is
    the single string ``'Hello, world'``.
    """
    noisy_text = 'HHHH eee lll lll ooo ,,, www oooo rrr lll ddd'
    target_text = 'Hello, world'

    table = CharacterTable()
    label_codes = list(map(table.encode, noisy_text))
    one_hot = to_categorical(label_codes, num_classes=len(table))
    # Add a leading batch axis: (1, number of input characters, classes).
    batch = one_hot.reshape(1, len(noisy_text), -1)
    return PreLoadedSource(batch, [target_text])
def test_decode_labels(self):
    """Decode a sequence with repeated labels and blank separators.

    Every character of ``self.original_text`` is repeated four times and
    followed by four blank labels (``len(char_table)``); the sequence also
    starts with four blanks. Decoding must restore the original text.
    """
    table = CharacterTable()
    ctc_decoder = CTCOutputDecoder(table)
    blank_code = len(table)
    repeats = 4

    blank_run = [blank_code] * repeats
    sequence = list(blank_run)
    for character in self.original_text:
        sequence += [table.encode(character)] * repeats
        sequence += blank_run

    self.assertEqual(ctc_decoder.decode(sequence), self.original_text)
def test_decode_sequence_of_blanks(self):
    """Fifty blank labels in a row decode to the empty string."""
    table = CharacterTable()
    ctc_decoder = CTCOutputDecoder(table)
    blank_code = len(table)
    only_blanks = [blank_code for _ in range(50)]
    self.assertEqual(ctc_decoder.decode(only_blanks), '')
def setUp(self):
    """Prepare sample sequences and a ``Seq2seqAdapter`` for the tests.

    Stores two input sequences, two output sequences, the character table,
    the encoded start and sentinel characters, and an adapter built from
    those codes and the table size.
    """
    self.seqs_in = [
        [[1, 1], [2, 2], [3, 3]],
        [[4, 4]],
    ]
    self.seqs_out = [
        [34, 85, 23],
        [28],
    ]

    table = CharacterTable()
    self.char_table = table
    self.start = table.encode(table.start)
    self.sentinel = table.encode(table.sentinel)

    self.adapter = Seq2seqAdapter(
        self.start,
        self.sentinel,
        num_classes=len(table),
    )
def create_distribution(self, codes):
    """Return a one-hot matrix with one row per entry of ``codes``.

    The matrix has ``len(CharacterTable()) + 1`` columns (the extra column
    presumably holds the blank label -- TODO confirm). Row ``i`` is all
    zeros except for a ``1.0`` in column ``codes[i]``.
    """
    num_columns = len(CharacterTable()) + 1
    num_rows = len(codes)
    one_hot = np.zeros((num_rows, num_columns))
    for row, column in enumerate(codes):
        one_hot[row, column] = 1.0
    return one_hot
def test_mapping_is_one_to_one(self):
    """Check that codes and characters map to each other without clashes.

    First every code in ``range(len(char_table))`` is decoded and the
    resulting characters must all be distinct. Then those characters are
    encoded again and the resulting codes must all be distinct as well.
    """
    char_table = CharacterTable()

    decoded_chars = [
        char_table.decode(code) for code in range(len(char_table))
    ]
    self.assertEqual(
        len(decoded_chars), len(set(decoded_chars)),
        'Got duplicate characters from different codes: {}'.format(
            decoded_chars))

    encoded_chars = [char_table.encode(ch) for ch in decoded_chars]
    # The message needs a '{}' placeholder; without it str.format silently
    # drops the argument and the offending codes never reach the report.
    self.assertEqual(
        len(encoded_chars), len(set(encoded_chars)),
        '2 or more characters got mapped to the same code: {}'.format(
            encoded_chars))
def setUp(self):
    """Prepare a two-word dictionary with transition weights.

    Stores a character table, the codes used for "hello" (0) and
    "world" (1), and a ``WordDictionary`` built from the two words and the
    pairwise transition values.
    """
    word_transitions = {
        ("hello", "world"): 0.75,
        ("hello", "hello"): 0.25,
        ("world", "hello"): 0.3,
        ("world", "world"): 0.7,
    }
    words = ["hello", "world"]

    self.char_table = CharacterTable()
    self.hello_code, self.world_code = 0, 1
    self.dictionary = WordDictionary(words, word_transitions)
def decode_next(self, prev_y, prev_state):
    """Return a scripted next-step distribution and a dummy state.

    Both returned lists have ``len(CharacterTable())`` entries filled with
    a small epsilon. The entry for the next code of ``self.result`` is set
    to ``1.0`` and ``self.counter`` advances by one. Once ``self.result``
    is exhausted, code 27 is emitted instead (presumably the end-of-
    sequence code -- TODO confirm). The arguments are not used.
    """
    table_size = len(CharacterTable())
    floor = 0.001
    probabilities = [floor] * table_size
    state = [floor] * table_size

    exhausted = self.counter >= len(self.result)
    code = 27 if exhausted else self.result[self.counter]

    probabilities[code] = 1.0
    self.counter += 1
    return probabilities, state
def decode_next(self, prev_y, prev_state):
    """Return a flat distribution and state, ignoring both arguments.

    Each of the two returned lists has ``len(CharacterTable())`` entries,
    all equal to 0.001.
    """
    table_size = len(CharacterTable())
    floor = 0.001
    probabilities = [floor] * table_size
    state = [floor] * table_size
    return probabilities, state
def test_decode_sequence_with_unknown_characters(self):
    """Three labels expected to be unknown decode to three '?' marks."""
    ctc_decoder = CTCOutputDecoder(CharacterTable())
    unknown_labels = [165, 123568, 123586]
    decoded = ctc_decoder.decode(unknown_labels)
    self.assertEqual(decoded, '???')
def test_decode_empty_sequence(self):
    """Decoding an empty label list yields the empty string."""
    ctc_decoder = CTCOutputDecoder(CharacterTable())
    decoded = ctc_decoder.decode([])
    self.assertEqual(decoded, '')
def test_sentinel(self):
    """The table's sentinel character survives an encode/decode round trip."""
    table = CharacterTable()
    expected = table.sentinel
    sentinel_code = table.encode(expected)
    self.assertEqual(table.decode(sentinel_code), expected)
def get_initial_state(self):
    """Return a list of ``len(CharacterTable())`` entries, each 0.001."""
    table_size = len(CharacterTable())
    floor = 0.001
    return [floor for _ in range(table_size)]
def test_decode_out_of_alphabet(self):
    """The code equal to the table size decodes to '?'."""
    table = CharacterTable()
    first_code_past_table = len(table)
    decoded = table.decode(first_code_past_table)
    self.assertEqual(decoded, '?')