Пример #1
0
    def test_vocab_without_unk(self):
        """Without an '<unk>' special, OOV lookups fail instead of defaulting.

        The vocab is built twice, with the '<pad>' special placed first and
        then last. Each build checks the full index layout, where '<pad>'
        lands, and that an out-of-vocabulary word is neither stored nor
        silently mapped to some fallback index.
        """
        counts = Counter({
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        })
        missing = 'OOVWORD'
        self.assertNotIn(missing, counts)

        # Regular tokens that survive min_freq=3, in the expected vocab order.
        frequent = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        built = {}
        for put_first in (True, False):
            voc = vocab.Vocab(counts,
                              min_freq=3,
                              specials=['<pad>'],
                              specials_first=put_first)
            if put_first:
                layout = ['<pad>'] + frequent
            else:
                layout = frequent + ['<pad>']
            self.assertEqual(voc.itos, layout)
            self.assertEqual(dict(voc.stoi),
                             {tok: pos for pos, tok in enumerate(layout)})
            self.assertNotIn(missing, voc.itos)
            self.assertNotIn(missing, voc.stoi)
            built[put_first] = voc

        head, tail = built[True], built[False]
        # '<pad>' must occupy the first slot, respectively the last slot.
        self.assertEqual(head.stoi['<pad>'], 0)
        self.assertEqual(tail.stoi['<pad>'], max(tail.stoi.values()))

        # With no unk special there is no default index: lookups must fail.
        for voc in (head, tail):
            with self.assertRaises(KeyError):
                voc.stoi[missing]
Пример #2
0
 def test_has_unk(self):
     """A token missing from a Vocab built with default arguments maps to 0."""
     token_counts = Counter({
         'hello': 4,
         'world': 3,
         'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
         'freq_too_low': 2
     })
     default_vocab = vocab.Vocab(token_counts)
     # NOTE(review): relies on Vocab's default specials placing the unk
     # token at index 0; the defaults are not visible from this file.
     fallback_index = default_vocab['not_in_it']
     self.assertEqual(fallback_index, 0)
Пример #3
0
    def test_vocab_specials_first(self):
        """``specials_first`` decides whether specials precede or follow tokens.

        With ``max_size=2`` both layouts hold the two specials plus two
        regular words, so the cap applies to the regular tokens only.
        """
        counts = Counter("a a b b c c".split())
        special_tokens = ['<pad>', '<eos>']
        kept_words = ['a', 'b']

        # (extra constructor kwargs, expected itos). The first case relies on
        # the constructor's default placement, i.e. specials first.
        cases = [
            ({}, special_tokens + kept_words),
            ({'specials_first': False}, kept_words + special_tokens),
        ]
        for extra_kwargs, layout in cases:
            v = vocab.Vocab(counts,
                            max_size=2,
                            specials=list(special_tokens),
                            **extra_kwargs)
            self.assertEqual(v.itos, layout)
            self.assertEqual(dict(v.stoi),
                             {tok: pos for pos, tok in enumerate(layout)})
Пример #4
0
 def test_serialization(self):
     """A Vocab survives a pickle round trip through a file unchanged."""
     c = Counter({
         'hello': 4,
         'world': 3,
         'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
         'freq_too_low': 2
     })
     v = vocab.Vocab(c, min_freq=3, specials=['<unk>', '<pad>', '<bos>'])
     pickle_path = os.path.join(self.test_dir, "vocab.pkl")
     # Close the writer before reading back. An anonymous open() leaks the
     # handle and only flushes when garbage-collected, which is not
     # guaranteed to happen before the read on non-refcounting interpreters.
     with open(pickle_path, "wb") as f:
         pickle.dump(v, f)
     with open(pickle_path, "rb") as f:
         v_loaded = pickle.load(f)
     # Compare via Vocab.__eq__ explicitly. A bare `assert` is stripped
     # under `python -O`, and the test base class's assertEqual may treat
     # a Vocab as a container.
     self.assertTrue(v == v_loaded)
Пример #5
0
    def test_vocab_basic(self):
        """Specials come first, in order, followed by tokens meeting min_freq."""
        counts = Counter({
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        })
        # 'freq_too_low' (count 2) is below min_freq=3 and must be absent.
        expected_itos = ['<unk>', '<pad>', '<bos>'] + [
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'
        ]
        v = vocab.Vocab(counts, min_freq=3, specials=expected_itos[:3])

        expected_stoi = dict(zip(expected_itos, range(len(expected_itos))))
        self.assertEqual(v.itos, expected_itos)
        self.assertEqual(dict(v.stoi), expected_stoi)
Пример #6
0
 def test_errors(self):
     """Vocab rejects unknown pretrained aliases and unsupported vector types."""
     c = Counter({
         'hello': 4,
         'world': 3,
         'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
         'freq_too_low': 2
     })
     # Each construction needs its own assertRaises context. Inside a single
     # context, the first call raising ends the block, so any later call in
     # it would never run and its case would go silently untested.
     with self.assertRaises(ValueError):
         # Unknown string alias, passed inside a list.
         vocab.Vocab(c,
                     min_freq=3,
                     specials=['<unk>', '<pad>', '<bos>'],
                     vectors=["fasttext.english.300d"])
     with self.assertRaises(ValueError):
         # Unknown string alias, passed as a bare string.
         vocab.Vocab(c,
                     min_freq=3,
                     specials=['<unk>', '<pad>', '<bos>'],
                     vectors="fasttext.english.300d")
     with self.assertRaises(ValueError):
         # Test proper error is raised when vectors argument is
         # non-string or non-Vectors
         vocab.Vocab(c,
                     min_freq=3,
                     specials=['<unk>', '<pad>', '<bos>'],
                     vectors={"word": [1, 2, 3]})
Пример #7
0
 def test_vocab_set_vectors(self):
     """set_vectors copies rows for known words and zero-fills everything else."""
     counts = Counter({
         'hello': 4,
         'world': 3,
         'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
         'test': 4,
         'freq_too_low': 2
     })
     v = vocab.Vocab(counts, min_freq=3, specials=['<unk>', '<pad>', '<bos>'])
     # Row i of `source_vectors` belongs to the word mapped to i here.
     source_index = {"hello": 0, "world": 1, "test": 2}
     source_vectors = torch.FloatTensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
     v.set_vectors(source_index, source_vectors, 2)

     # Rows follow the vocab order. The three specials and the unicode word
     # have no source row and stay zero; hello, test and world then take
     # their source rows.
     zero_row = [0.0, 0.0]
     expected_vectors = np.array(
         [zero_row] * 4 + [[0.1, 0.2], [0.5, 0.6], [0.3, 0.4]])
     self.assertEqual(v.vectors, expected_vectors, exact_dtype=False)
Пример #8
0
    def test_serialization_backcompat(self):
        """Unpickling works for vocabularies saved without an ``unk_index``.

        Older saved states were not required to carry ``unk_index``. This
        test deletes it from a fresh Vocab to mimic such a state, then checks
        that a pickle round trip still yields an equal object.
        """
        c = Counter({
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        })
        v = vocab.Vocab(c, min_freq=3, specials=['<pad>',
                                                 '<bos>'])  # no unk special
        # Mock old vocabulary
        del v.__dict__["unk_index"]

        pickle_path = os.path.join(self.test_dir, "vocab.pkl")
        # Close the writer before reading back, so the data is flushed
        # deterministically and neither file handle is leaked.
        with open(pickle_path, "wb") as f:
            pickle.dump(v, f)
        with open(pickle_path, "rb") as f:
            v_loaded = pickle.load(f)
        # Compare via Vocab.__eq__ explicitly. A bare `assert` is stripped
        # under `python -O`.
        self.assertTrue(v == v_loaded)
Пример #9
0
    def build_vocab(self, *args, **kwargs):
        """Build ``self.vocab`` from token counts gathered over ``args``.

        Args:
            *args: Data sources. For a ``textdata.Dataset``, the attribute of
                every field that is this object is used. Any other argument is
                used directly as a source.
            **kwargs: Forwarded to ``vocab.Vocab``. A ``specials`` entry is
                appended after ``[self.unk_token, self.pad_token]``, skipping
                duplicates, instead of clashing with the ``specials`` argument
                this method always passes.
        """
        sources = []
        for arg in args:
            if isinstance(arg, textdata.Dataset):
                sources += [
                    getattr(arg, name)
                    for name, field in arg.fields.items()
                    if field is self
                ]
            else:
                sources.append(arg)

        counter = Counter()
        for data in sources:
            for x in data:
                # Only the first element of each non-empty example is counted.
                # NOTE(review): presumably x[0] holds this field's tokens;
                # confirm against the preprocessing that produces `x`.
                if len(x) > 0:
                    counter.update(x[0])

        # Forwarding a caller-supplied `specials` alongside specials=... would
        # raise TypeError ("got multiple values for keyword argument"), so
        # merge the caller's extras in after the built-in unk/pad tokens.
        specials = [self.unk_token, self.pad_token]
        for tok in kwargs.pop('specials', None) or ():
            if tok not in specials:
                specials.append(tok)
        self.vocab = vocab.Vocab(counter, specials=specials, **kwargs)
Пример #10
0
    def build_vocab(self, *args, **kwargs):
        """Build a character-level ``self.vocab`` from the given sources.

        Args:
            *args: Data sources. For a ``textdata.Dataset``, the attribute of
                every field that is this object is used. Any other argument is
                used directly as a source.
            **kwargs: Forwarded to ``vocab.Vocab``. A ``specials`` entry is
                appended after ``[self.unk_token, self.pad_token]``, skipping
                duplicates, instead of clashing with the ``specials`` argument
                this method always passes.
        """
        sources = []
        for arg in args:
            if isinstance(arg, textdata.Dataset):
                sources += [
                    getattr(arg, name) for name, field in arg.fields.items()
                    if field is self
                ]
            else:
                sources.append(arg)

        counter = Counter()
        for data in sources:
            # data is the return value of preprocess().
            for sentence in data:
                for word_chars in sentence:
                    # update treats word as an iterable, so this will add all
                    # the characters from the word, not the word itself.
                    counter.update(word_chars)

        # Forwarding a caller-supplied `specials` alongside specials=... would
        # raise TypeError ("got multiple values for keyword argument"), so
        # merge the caller's extras in after the built-in unk/pad tokens.
        specials = [self.unk_token, self.pad_token]
        for tok in kwargs.pop('specials', None) or ():
            if tok not in specials:
                specials.append(tok)

        self.vocab = vocab.Vocab(counter, specials=specials, **kwargs)