def test_vocab_load_and_save(self):
    """Check that a ``Vocab`` keeps its contents across ``torch.save``/``torch.load``.

    A vocabulary is built from a frequency-ordered ``OrderedDict`` with
    ``min_freq=3``; the expected token list omits ``'freq_too_low'``
    (frequency 2).  The vocabulary is validated once after construction,
    written to ``self.test_dir``, loaded back, and the *reloaded* object
    is validated against the same expectations.
    """
    token_to_freq = {
        'hello': 4,
        'world': 3,
        'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
        'freq_too_low': 2,
    }
    # Highest frequency first; the OrderedDict preserves that ordering.
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

    c = OrderedDict(sorted_by_freq_tuples)
    v = Vocab(c, min_freq=3)

    expected_itos = ['<unk>', '<pad>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}

    # Sanity-check the freshly constructed vocabulary before serializing it.
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)

    vocab_path = os.path.join(self.test_dir, 'vocab.pt')
    torch.save(v, vocab_path)
    loaded_v = torch.load(vocab_path)

    # Both checks must target the reloaded object.  Asserting on ``v`` here
    # would only repeat the pre-save check and leave the round-tripped
    # index-to-token list unverified.
    self.assertEqual(loaded_v.get_itos(), expected_itos)
    self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
def test_vocab_specials_order(self):
    """Verify where special tokens are placed relative to regular tokens.

    The same counter is used to build two vocabularies: one with the
    default placement of ``specials`` and one with ``specials_first=False``.
    Only the token-to-index mapping is compared in each case.
    """
    frequencies = {'a': 2, 'b': 2, 'c': 2}
    ranked_pairs = sorted(frequencies.items(), key=lambda pair: pair[1], reverse=True)
    counter = OrderedDict(ranked_pairs)

    # Default call: the test expects the specials ahead of 'a', 'b', 'c'.
    vocab_specials_front = Vocab(counter, specials=['<pad>', '<unk>'])
    front_itos = ['<pad>', '<unk>', 'a', 'b', 'c']
    front_stoi = dict(zip(front_itos, range(len(front_itos))))
    self.assertEqual(dict(vocab_specials_front.get_stoi()), front_stoi)

    # specials_first=False: the test expects the specials after 'a', 'b', 'c'.
    vocab_specials_back = Vocab(counter, specials=['<pad>', '<unk>'], specials_first=False)
    back_itos = ['a', 'b', 'c', '<pad>', '<unk>']
    back_stoi = dict(zip(back_itos, range(len(back_itos))))
    self.assertEqual(dict(vocab_specials_back.get_stoi()), back_stoi)
def test_vocab_insert_token(self):
    """Exercise ``Vocab.insert_token`` at the last and at the first index.

    A fresh ``Vocab`` is built from the same counter for each scenario, and
    both the index-to-token list and the token-to-index mapping are checked
    after the insertion.
    """
    counter = OrderedDict({'<unk>': 2, 'a': 2})

    # Scenario 1: insert 'b' at index 2, i.e. after the existing tokens.
    appended = Vocab(counter)
    appended.insert_token('b', 2)
    itos_after_append = ['<unk>', 'a', 'b']
    stoi_after_append = dict(zip(itos_after_append, range(len(itos_after_append))))
    self.assertEqual(appended.get_itos(), itos_after_append)
    self.assertEqual(dict(appended.get_stoi()), stoi_after_append)

    # Scenario 2: insert 'b' at index 0; the existing tokens are expected to
    # shift one position to the right.
    prepended = Vocab(counter)
    prepended.insert_token('b', 0)
    itos_after_prepend = ['b', '<unk>', 'a']
    stoi_after_prepend = dict(zip(itos_after_prepend, range(len(itos_after_prepend))))
    self.assertEqual(prepended.get_itos(), itos_after_prepend)
    self.assertEqual(dict(prepended.get_stoi()), stoi_after_prepend)
def test_vocab_basic(self):
    """Build a ``Vocab`` with ``min_freq=3`` and check both lookup tables.

    The expected token list contains the three tokens whose frequency is at
    least 3, ordered by descending frequency, followed by ``'<unk>'``;
    ``'freq_too_low'`` (frequency 2) is expected to be absent.
    """
    frequencies = {
        'hello': 4,
        'world': 3,
        'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
        'freq_too_low': 2,
    }
    # Most frequent token first; the OrderedDict keeps that order.
    ranked_pairs = sorted(frequencies.items(), key=lambda pair: pair[1], reverse=True)
    counter = OrderedDict(ranked_pairs)

    vocab_under_test = Vocab(counter, min_freq=3)

    want_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', '<unk>']
    want_stoi = dict(zip(want_itos, range(len(want_itos))))

    self.assertEqual(vocab_under_test.get_itos(), want_itos)
    self.assertEqual(dict(vocab_under_test.get_stoi()), want_stoi)