Ejemplo n.º 1
0
class TestUnicodeCharsVocabulary(unittest.TestCase):
    """Exercise the character-id encoding of ``UnicodeCharsVocabulary``.

    Every test runs against a six-entry vocabulary file written by
    ``setUp`` and a vocabulary built with ``5`` as its second argument
    (presumably the maximum word length, since every expected row holds
    five ids -- confirm against ``UnicodeCharsVocabulary``).
    """

    def setUp(self):
        """Write the vocabulary file and build the vocabulary under test."""
        words = ['the', '.', chr(256) + 't', '<S>', '</S>', '<UNK>']
        # mkstemp() returns an already-open descriptor. Wrapping it with
        # os.fdopen() lets the ``with`` block close it; re-opening the path
        # by name would leave that descriptor open for the rest of the run.
        fd, tmp = tempfile.mkstemp()
        self._tmp = tmp
        with os.fdopen(fd, 'w') as fout:
            fout.write('\n'.join(words))
        self.vocab = UnicodeCharsVocabulary(tmp, 5)

    def test_vocab_word_to_char_ids(self):
        """Check ``word_to_char_ids`` for short, long and non-ASCII words."""
        # Two characters: the expected row ends with a trailing 260.
        char_ids = self.vocab.word_to_char_ids('th')
        expected = np.array([258, 116, 104, 259, 260], dtype=np.int32)
        np.testing.assert_array_equal(char_ids, expected)

        # Six characters: only the first three appear between 258 and 259.
        char_ids = self.vocab.word_to_char_ids('thhhhh')
        expected = np.array([258, 116, 104, 104, 259], dtype=np.int32)
        np.testing.assert_array_equal(char_ids, expected)

        # chr(256) is expected as 196, 128 -- its two UTF-8 bytes.
        char_ids = self.vocab.word_to_char_ids(chr(256) + 't')
        expected = np.array([258, 196, 128, 116, 259], dtype=np.int32)
        np.testing.assert_array_equal(char_ids, expected)

    def test_bos_eos(self):
        """Check that '<S>' and '</S>' map to ``bos_chars`` / ``eos_chars``.

        Each token is looked up two ways: through ``word_to_char_ids`` and
        through the ``word_char_ids`` table indexed by ``word_to_id``.
        """
        bos_ids = self.vocab.word_to_char_ids('<S>')
        np.testing.assert_array_equal(bos_ids, self.vocab.bos_chars)

        bos_ids = self.vocab.word_char_ids[self.vocab.word_to_id('<S>')]
        np.testing.assert_array_equal(bos_ids, self.vocab.bos_chars)

        eos_ids = self.vocab.word_to_char_ids('</S>')
        np.testing.assert_array_equal(eos_ids, self.vocab.eos_chars)

        eos_ids = self.vocab.word_char_ids[self.vocab.word_to_id('</S>')]
        np.testing.assert_array_equal(eos_ids, self.vocab.eos_chars)

    def test_vocab_encode_chars(self):
        """Check ``encode_chars`` on a three-word sentence.

        The expected matrix has five rows: one extra row before and one
        after the three word rows.
        """
        sentence = ' '.join(['th', 'thhhhh', chr(256) + 't'])
        char_ids = self.vocab.encode_chars(sentence)
        expected = np.array(
            [[258, 256, 259, 260, 260], [258, 116, 104, 259, 260],
             [258, 116, 104, 104, 259], [258, 196, 128, 116, 259],
             [258, 257, 259, 260, 260]],
            dtype=np.int32)
        np.testing.assert_array_equal(char_ids, expected)

    def test_vocab_encode_chars_reverse(self):
        """Check ``encode_chars(..., reverse=True)``.

        The words are joined in reversed order and the expected matrix is
        the forward expectation with its rows flipped.
        """
        sentence = ' '.join(reversed(['th', 'thhhhh', chr(256) + 't']))
        vocab = UnicodeCharsVocabulary(self._tmp, 5)
        char_ids = vocab.encode_chars(sentence, reverse=True)
        expected = np.array(
            [[258, 256, 259, 260, 260], [258, 116, 104, 259, 260],
             [258, 116, 104, 104, 259], [258, 196, 128, 116, 259],
             [258, 257, 259, 260, 260]],
            dtype=np.int32)[::-1, :]
        np.testing.assert_array_equal(char_ids, expected)

    def tearDown(self):
        """Delete the temporary vocabulary file created by ``setUp``."""
        os.remove(self._tmp)
Ejemplo n.º 2
0
 def test_vocab_encode_chars_reverse(self):
     """Check ``encode_chars(..., reverse=True)`` on a three-word sentence.

     The words are joined in reversed order and the expected matrix is the
     forward expectation with its rows flipped.
     """
     sentence = ' '.join(reversed(['th', 'thhhhh', chr(256) + 't']))
     vocab = UnicodeCharsVocabulary(self._tmp, 5)
     char_ids = vocab.encode_chars(sentence, reverse=True)
     forward_rows = np.array(
         [[258, 256, 259, 260, 260], [258, 116, 104, 259, 260],
          [258, 116, 104, 104, 259], [258, 196, 128, 116, 259],
          [258, 257, 259, 260, 260]],
         dtype=np.int32)
     expected = forward_rows[::-1, :]
     # assert_array_equal fails on any shape mismatch and reports the
     # differing elements; ``(a == b).all()`` can pass through broadcasting
     # and only ever reports "False is not true".
     np.testing.assert_array_equal(char_ids, expected)