def test_penntreebank(self):
    from torchtext.experimental.datasets import PennTreebank
    # smoke test to ensure penn treebank works properly
    train_dataset, valid_dataset, test_dataset = PennTreebank()
    train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
    valid_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, valid_dataset)))
    test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_dataset)))
    self._helper_test_func(len(train_data), 924412, train_data[20:25],
                           [9919, 9920, 9921, 9922, 9188])
    self._helper_test_func(len(test_data), 82114, test_data[30:35],
                           [397, 93, 4, 16, 7])
    self._helper_test_func(len(valid_data), 73339, valid_data[40:45],
                           [0, 0, 78, 426, 196])

    vocab = train_dataset.get_vocab()
    tokens_ids = [vocab[token] for token in 'the player characters rest'.split()]
    self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

    # Add test for the subset of the standard datasets
    train_dataset, test_dataset = PennTreebank(data_select=('train', 'test'))
    train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
    test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_dataset)))
    self._helper_test_func(len(train_data), 924412, train_data[20:25],
                           [9919, 9920, 9921, 9922, 9188])
    self._helper_test_func(len(test_data), 82114, test_data[30:35],
                           [397, 93, 4, 16, 7])

    train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(data_select=('train', 'test'))
    self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15], ' aer banknote b')
    self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25], " no it was n't black mond")
    del train_iter, test_iter
def test_penntreebank(self):
    from torchtext.experimental.datasets import PennTreebank
    # smoke test to ensure penn treebank works properly
    train_dataset, test_dataset, valid_dataset = PennTreebank()
    self.assertEqual(len(train_dataset), 924412)
    self.assertEqual(len(test_dataset), 82114)
    self.assertEqual(len(valid_dataset), 73339)

    vocab = train_dataset.get_vocab()
    tokens_ids = [vocab[token] for token in 'the player characters rest'.split()]
    self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])
def test_penntreebank(self):
    from torchtext.experimental.datasets import PennTreebank
    # smoke test to ensure penn treebank works properly
    train_dataset, test_dataset, valid_dataset = PennTreebank()
    self.assertEqual(len(train_dataset), 924412)
    self.assertEqual(len(test_dataset), 82114)
    self.assertEqual(len(valid_dataset), 73339)

    vocab = train_dataset.get_vocab()
    tokens_ids = [
        vocab[token] for token in 'the player characters rest'.split()
    ]
    self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

    # Delete the dataset after we're done to save disk space on CI
    datafile = os.path.join(self.project_root, ".data", 'ptb.train.txt')
    conditional_remove(datafile)
    datafile = os.path.join(self.project_root, ".data", 'ptb.test.txt')
    conditional_remove(datafile)
    datafile = os.path.join(self.project_root, ".data", 'ptb.valid.txt')
    conditional_remove(datafile)
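# Minimal usage sketch of the experimental PennTreebank interface exercised by
# the tests above. The split unpacking, get_vocab(), and the per-line tensor
# filtering are taken from those tests; `batchify` is a generic language-model
# batching helper added here for illustration, not part of the torchtext API.
import torch
from torchtext.experimental.datasets import PennTreebank


def batchify(dataset, batch_size):
    # Concatenate the non-empty per-line tensors into one token stream,
    # then reshape into batch_size contiguous columns.
    data = torch.cat(tuple(filter(lambda t: t.numel() > 0, dataset)))
    n_batches = data.numel() // batch_size
    return data[:n_batches * batch_size].view(batch_size, -1).t().contiguous()


train_dataset, valid_dataset, test_dataset = PennTreebank()
vocab = train_dataset.get_vocab()

# Numericalize a sample sentence with the vocab built from the training split.
token_ids = [vocab[token] for token in 'the player characters rest'.split()]

# Shape: (sequence_length, batch_size), ready for a language-model training loop.
train_batches = batchify(train_dataset, batch_size=20)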