class TestVocabulary(unittest.TestCase):
    """Tests building a Vocabulary from token lists and round-tripping it through files."""

    def setUp(self):
        self.file = "test/data/toy/train.de"
        sent = "Die Wahrheit ist, dass die Titanic – obwohl sie alle " \
               "Kinokassenrekorde bricht – nicht gerade die aufregendste " \
               "Geschichte vom Meer ist. GROẞ"  # ẞ (in uppercase) requires Unicode
        self.word_list = sent.split()  # whitespace tokens; may contain duplicates
        self.char_list = list(sent)
        self.temp_file_char = "tmp.src.char"
        self.temp_file_word = "tmp.src.word"
        # sorted(set(...)) deduplicates; the extra list() wrapper was redundant
        self.word_vocab = Vocabulary(tokens=sorted(set(self.word_list)))
        self.char_vocab = Vocabulary(tokens=sorted(set(self.char_list)))

    def tearDown(self):
        # Clean up temp files even when a test assertion fails mid-way,
        # so a failing run does not leave stale files for the next run.
        for path in (self.temp_file_char, self.temp_file_word):
            if os.path.exists(path):
                os.remove(path)

    def testVocabularyFromList(self):
        # vocab size == number of unique tokens plus the special symbols
        self.assertEqual(
            len(self.word_vocab) - len(self.word_vocab.specials),
            len(set(self.word_list)))
        self.assertEqual(
            len(self.char_vocab) - len(self.char_vocab.specials),
            len(set(self.char_list)))
        expected_char_itos = [
            '<unk>', '<pad>', '<s>', '</s>', ' ', ',', '.', 'D', 'G',
            'K', 'M', 'O', 'R', 'T', 'W', 'a', 'b', 'c', 'd', 'e', 'f',
            'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u',
            'v', 'w', 'ẞ', '–'
        ]
        self.assertEqual(self.char_vocab.itos, expected_char_itos)
        expected_word_itos = [
            '<unk>', '<pad>', '<s>', '</s>', 'Die', 'GROẞ', 'Geschichte',
            'Kinokassenrekorde', 'Meer', 'Titanic', 'Wahrheit', 'alle',
            'aufregendste', 'bricht', 'dass', 'die', 'gerade', 'ist,',
            'ist.', 'nicht', 'obwohl', 'sie', 'vom', '–'
        ]
        self.assertEqual(self.word_vocab.itos, expected_word_itos)

    def testVocabularyFromFile(self):
        # write vocabs to file and create new ones from those files;
        # the temp files are removed in tearDown (even on failure)
        self.word_vocab.to_file(self.temp_file_word)
        self.char_vocab.to_file(self.temp_file_char)
        word_vocab2 = Vocabulary(file=self.temp_file_word)
        char_vocab2 = Vocabulary(file=self.temp_file_char)
        self.assertEqual(self.word_vocab.itos, word_vocab2.itos)
        self.assertEqual(self.char_vocab.itos, char_vocab2.itos)

    def testIsUnk(self):
        self.assertTrue(self.word_vocab.is_unk("BLA"))
        self.assertFalse(self.word_vocab.is_unk("Die"))
        self.assertFalse(self.word_vocab.is_unk("GROẞ"))
        self.assertTrue(self.char_vocab.is_unk("x"))
        self.assertFalse(self.char_vocab.is_unk("d"))
        self.assertFalse(self.char_vocab.is_unk("ẞ"))
def build_vocab(field, max_size, min_freq, data, vocab_file=None):
    """
    Builds vocabulary for a torchtext `field`, either from a saved
    vocabulary file or from the token frequencies in `data`.

    :param field: which side of the examples to read tokens from,
        "src" or "trg"
    :param max_size: maximum number of non-special tokens kept
    :param min_freq: minimum frequency for a token to be kept
        (no frequency filtering if < 0)
    :param data: dataset whose `examples` carry `src`/`trg` token lists
    :param vocab_file: optional path to a previously saved vocabulary;
        when given, frequencies in `data` are ignored
    :return: Vocabulary containing the special symbols plus the selected tokens
    """
    # special symbols; UNK must end up at DEFAULT_UNK_ID
    specials = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
        vocab.add_tokens(specials)
    else:
        # create newly from token frequencies
        def filter_min(counter, min_freq):
            """ Filter counter by min frequency """
            return Counter(
                {t: c for t, c in counter.items() if c >= min_freq})

        def sort_and_cut(counter, limit):
            """ Cut counter to most frequent,
            sorted numerically and alphabetically"""
            # sort alphabetically first, then (stably) by frequency
            # descending, so ties are broken alphabetically
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        tokens = []
        for i in data.examples:
            if field == "src":
                tokens.extend(i.src)
            elif field == "trg":
                tokens.extend(i.trg)

        counter = Counter(tokens)
        if min_freq > -1:
            counter = filter_min(counter, min_freq)
        vocab_tokens = specials + sort_and_cut(counter, max_size)
        assert vocab_tokens[DEFAULT_UNK_ID()] == UNK_TOKEN
        assert len(vocab_tokens) <= max_size + len(specials)
        vocab = Vocabulary(tokens=vocab_tokens)

    # sanity check: every special symbol except UNK must be known
    # to the vocabulary (i.e. NOT map to the unknown id)
    for s in specials[1:]:
        assert not vocab.is_unk(s)

    return vocab