def test_penntreebank(self):
    """Smoke test for the experimental PennTreebank dataset wrappers.

    Checks full train/valid/test lengths, spot-checks token-id slices,
    the vocab lookup, the ``data_select`` subset path, and the raw
    iterator API.
    """
    from torchtext.experimental.datasets import PennTreebank

    def _cat_nonempty(dataset):
        # Flatten a dataset of per-line token tensors into one 1-D tensor,
        # skipping empty lines (numel() == 0) which torch.cat would reject.
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, dataset)))

    # smoke test to ensure penn treebank works properly
    train_dataset, valid_dataset, test_dataset = PennTreebank()
    train_data = _cat_nonempty(train_dataset)
    valid_data = _cat_nonempty(valid_dataset)
    test_data = _cat_nonempty(test_dataset)
    self._helper_test_func(len(train_data), 924412, train_data[20:25],
                           [9919, 9920, 9921, 9922, 9188])
    self._helper_test_func(len(test_data), 82114, test_data[30:35],
                           [397, 93, 4, 16, 7])
    self._helper_test_func(len(valid_data), 73339, valid_data[40:45],
                           [0, 0, 78, 426, 196])

    vocab = train_dataset.get_vocab()
    tokens_ids = [vocab[token] for token in 'the player characters rest'.split()]
    self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

    # Add test for the subset of the standard datasets
    train_dataset, test_dataset = PennTreebank(data_select=('train', 'test'))
    train_data = _cat_nonempty(train_dataset)
    test_data = _cat_nonempty(test_dataset)
    self._helper_test_func(len(train_data), 924412, train_data[20:25],
                           [9919, 9920, 9921, 9922, 9188])
    self._helper_test_func(len(test_data), 82114, test_data[30:35],
                           [397, 93, 4, 16, 7])

    # Raw iterators yield plain strings; check lengths and line prefixes.
    train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(
        data_select=('train', 'test'))
    self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15],
                           ' aer banknote b')
    self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25],
                           " no it was n't black mond")
    del train_iter, test_iter
def get_ptb(conf):
    """Return PennTreeBank train/test/valid sets plus the vocabulary."""
    splits = ('train', 'test', 'valid')
    # Two independent sets of raw iterators are requested — presumably
    # because load_dataset consumes one pass for vocab building and needs
    # a fresh pass for numericalisation (TODO confirm against load_dataset).
    raw_train, raw_test, raw_valid = PennTreebank(split=splits)
    raw_train2, raw_test2, raw_valid2 = PennTreebank(split=splits)

    train, test, valid, vocab = load_dataset(
        raw_train, raw_test, raw_valid,
        raw_train2, raw_test2, raw_valid2,
        conf,
    )
    return train, test, valid, vocab
def test_penntreebank(self): from torchtext.experimental.datasets import PennTreebank # smoke test to ensure wikitext2 works properly train_dataset, test_dataset, valid_dataset = PennTreebank() self.assertEqual(len(train_dataset), 924412) self.assertEqual(len(test_dataset), 82114) self.assertEqual(len(valid_dataset), 73339) vocab = train_dataset.get_vocab() tokens_ids = [vocab[token] for token in 'the player characters rest'.split()] self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])
def load_iter():
    """Build BPTT train/valid/test iterators over PennTreebank.

    Returns ``(train_iter, valid_iter, test_iter, TEXT)`` where TEXT is the
    torchtext Field holding the vocabulary built from the training split.
    """
    SEED = 1
    random.seed(SEED)
    torch.manual_seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    spacy_en = spacy.load('en')

    def tokenize_en(text):
        """Tokenizes English text from a string into a list of strings (tokens)"""
        return [token.text for token in spacy_en.tokenizer(text)]

    TEXT = Field(tokenize=tokenize_en,
                 init_token='<sos>',
                 eos_token='<eos>',
                 lower=True)

    train_set, valid_set, test_set = PennTreebank.splits(TEXT)
    # Vocabulary is built from the training split only.
    TEXT.build_vocab(train_set)

    train_iter, valid_iter, test_iter = BPTTIterator.splits(
        (train_set, valid_set, test_set),
        batch_size=64,
        bptt_len=6,
        device=device)
    return train_iter, valid_iter, test_iter, TEXT
def test_penntreebank(self):
    """Smoke test: experimental PennTreebank loads with expected sizes and vocab ids,
    then removes the downloaded data files."""
    from torchtext.experimental.datasets import PennTreebank
    # smoke test to ensure penn treebank works properly
    train_dataset, test_dataset, valid_dataset = PennTreebank()
    self.assertEqual(len(train_dataset), 924412)
    self.assertEqual(len(test_dataset), 82114)
    self.assertEqual(len(valid_dataset), 73339)
    # Spot-check that the built vocab maps a few known tokens to stable ids.
    vocab = train_dataset.get_vocab()
    tokens_ids = [
        vocab[token] for token in 'the player characters rest'.split()
    ]
    self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

    # Delete the dataset after we're done to save disk space on CI
    for split in ('train', 'test', 'valid'):
        datafile = os.path.join(self.project_root, ".data",
                                'ptb.{}.txt'.format(split))
        conditional_remove(datafile)