Example #1
import os
import unittest

from joeynmt.vocabulary import Vocabulary  # assumed import path (joeynmt layout)


class TestVocabulary(unittest.TestCase):
    def setUp(self):
        self.file = "test/data/toy/train.de"
        sent = "Die Wahrheit ist, dass die Titanic – obwohl sie alle " \
               "Kinokassenrekorde bricht – nicht gerade die aufregendste " \
               "Geschichte vom Meer ist. GROẞ"  # ẞ (in uppercase) requires Unicode
        self.word_list = sent.split()  # all tokens, duplicates included
        self.char_list = list(sent)
        self.temp_file_char = "tmp.src.char"
        self.temp_file_word = "tmp.src.word"
        self.word_vocab = Vocabulary(tokens=sorted(set(self.word_list)))
        self.char_vocab = Vocabulary(tokens=sorted(set(self.char_list)))

    def testVocabularyFromList(self):
        self.assertEqual(
            len(self.word_vocab) - len(self.word_vocab.specials),
            len(set(self.word_list)))
        self.assertEqual(
            len(self.char_vocab) - len(self.char_vocab.specials),
            len(set(self.char_list)))
        expected_char_itos = [
            '<unk>', '<pad>', '<s>', '</s>', ' ', ',', '.', 'D', 'G', 'K', 'M',
            'O', 'R', 'T', 'W', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
            'k', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'v', 'w', 'ẞ', '–'
        ]

        self.assertEqual(self.char_vocab.itos, expected_char_itos)
        expected_word_itos = [
            '<unk>', '<pad>', '<s>', '</s>', 'Die', 'GROẞ', 'Geschichte',
            'Kinokassenrekorde', 'Meer', 'Titanic', 'Wahrheit', 'alle',
            'aufregendste', 'bricht', 'dass', 'die', 'gerade', 'ist,', 'ist.',
            'nicht', 'obwohl', 'sie', 'vom', '–'
        ]
        self.assertEqual(self.word_vocab.itos, expected_word_itos)

    def testVocabularyFromFile(self):
        # write vocabs to file and create new ones from those files
        self.word_vocab.to_file(self.temp_file_word)
        self.char_vocab.to_file(self.temp_file_char)

        word_vocab2 = Vocabulary(file=self.temp_file_word)
        char_vocab2 = Vocabulary(file=self.temp_file_char)
        self.assertEqual(self.word_vocab.itos, word_vocab2.itos)
        self.assertEqual(self.char_vocab.itos, char_vocab2.itos)
        os.remove(self.temp_file_char)
        os.remove(self.temp_file_word)

    def testIsUnk(self):
        self.assertTrue(self.word_vocab.is_unk("BLA"))
        self.assertFalse(self.word_vocab.is_unk("Die"))
        self.assertFalse(self.word_vocab.is_unk("GROẞ"))
        self.assertTrue(self.char_vocab.is_unk("x"))
        self.assertFalse(self.char_vocab.is_unk("d"))
        self.assertFalse(self.char_vocab.is_unk("ẞ"))
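A minimal usage sketch of the Vocabulary API exercised by the test above (constructor with tokens=, itos, is_unk, and the file round-trip); the import path is an assumption based on the joeynmt layout:

import os

from joeynmt.vocabulary import Vocabulary  # assumed import path

# specials (<unk>, <pad>, <s>, </s>) are prepended automatically
vocab = Vocabulary(tokens=sorted({"die", "Titanic", "Meer"}))
print(vocab.itos)            # specials first, then the sorted tokens
print(vocab.is_unk("Meer"))  # False: in-vocabulary token
print(vocab.is_unk("BLA"))   # True: out-of-vocabulary token

# round-trip through a file, as in testVocabularyFromFile
vocab.to_file("tmp.vocab")
assert Vocabulary(file="tmp.vocab").itos == vocab.itos
os.remove("tmp.vocab")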
Example #2
from collections import Counter

# assumed import paths, following the joeynmt module layout
from joeynmt.constants import (UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN,
                               DEFAULT_UNK_ID)
from joeynmt.vocabulary import Vocabulary


def build_vocab(field, max_size, min_freq, data, vocab_file=None):
    """
    Builds a vocabulary over the tokens of a dataset `field`.

    :param field: which side to build the vocabulary for, "src" or "trg"
    :param max_size: maximum number of tokens (excluding special symbols)
    :param min_freq: minimum frequency for a token to be kept;
        pass -1 (or lower) to disable frequency filtering
    :param data: dataset whose `examples` carry `src`/`trg` token lists
    :param vocab_file: optional path to an existing vocabulary file; if
        given, the vocabulary is loaded from it instead of built from `data`
    :return: Vocabulary
    """

    # special symbols
    specials = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
        vocab.add_tokens(specials)
    else:
        # build the vocabulary from the data
        def filter_min(counter, min_freq):
            """ Filter counter by min frequency """
            filtered_counter = Counter({t: c for t, c in counter.items()
                                        if c >= min_freq})
            return filtered_counter

        def sort_and_cut(counter, limit):
            """ Cut counter to the `limit` most frequent tokens;
            ties in frequency are broken alphabetically """
            # sort alphabetically first, then stable-sort by frequency so
            # that equally frequent tokens stay in alphabetical order
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        tokens = []
        for ex in data.examples:
            if field == "src":
                tokens.extend(ex.src)
            elif field == "trg":
                tokens.extend(ex.trg)

        counter = Counter(tokens)
        if min_freq > -1:
            counter = filter_min(counter, min_freq)
        vocab_tokens = specials + sort_and_cut(counter, max_size)
        assert vocab_tokens[DEFAULT_UNK_ID()] == UNK_TOKEN
        assert len(vocab_tokens) <= max_size + len(specials)
        vocab = Vocabulary(tokens=vocab_tokens)

    # sanity check: every special symbol except UNK must be known
    # to the vocabulary (i.e. must not map to the UNK index)
    for s in specials[1:]:
        assert not vocab.is_unk(s)

    return vocab
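A hypothetical call of build_vocab for illustration: SimpleNamespace stands in for the torchtext-style dataset the function expects (an object exposing .examples whose items carry .src/.trg token lists); all names and values here are illustrative only.

from types import SimpleNamespace

# stand-in dataset: two toy sentence pairs
examples = [
    SimpleNamespace(src=["die", "Titanic", "die"],
                    trg=["the", "Titanic", "the"]),
    SimpleNamespace(src=["das", "Meer"], trg=["the", "sea"]),
]
data = SimpleNamespace(examples=examples)

src_vocab = build_vocab(field="src", max_size=100, min_freq=1, data=data)
print(src_vocab.itos)  # specials first, then tokens by descending frequency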