def test_text(self):
    fp = self.fp
    lines = self.lines

    data = TextDataset(fp.name)
    self.assertEqual(data._length, None)
    for x, y in zip(data, lines):
        self.assertEqual(x, y)
    for i, y in enumerate(lines):
        self.assertEqual(data[i], y)
    self.assertEqual(len(data), len(lines))
    self.assertEqual(data._length, len(lines))
    # check if length is cached
    self.assertEqual(len(data), len(lines))
    self.assertIsInstance(data._dataset, arrayfiles.TextFile)

    data = data.map(str.split)
    for x, y in zip(data, lines):
        self.assertEqual(x, y.split())
    self.assertIsInstance(data, lineflow.core.MapDataset)
    self.assertIsInstance(data._dataset, TextDataset)
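# Sketch of the fixtures the tests in this class rely on; the real setUp is
# outside this excerpt, and the sample sentences are placeholders (requires
# `import tempfile` at module scope):
def setUp(self):
    self.lines = ['This is a test .', 'That is also a test .']
    self.fp = tempfile.NamedTemporaryFile(mode='w+t')
    self.fp.write('\n'.join(self.lines) + '\n')
    self.fp.flush()

def tearDown(self):
    self.fp.close()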
def test_zips_multiple_files(self):
    fp = self.fp
    lines = self.lines

    data = TextDataset([fp.name, fp.name], mode='zip')
    for x, y in zip(data, lines):
        self.assertTupleEqual(x, (y, y))
    for j, y in enumerate(lines):
        self.assertTupleEqual(data[j], (y, y))
    self.assertEqual(len(data), len(lines))
    self.assertEqual(data._length, len(lines))
    self.assertIsInstance(data._dataset, lineflow.core.ZipDataset)
    self.assertIsInstance(data.map(lambda x: x)._dataset, TextDataset)
def test_concats_multiple_files(self):
    fp = self.fp
    lines = self.lines

    data = TextDataset([fp.name, fp.name], mode='concat')
    for x, y in zip(data, lines + lines):
        self.assertEqual(x, y)
    for j, y in enumerate(lines + lines):
        self.assertEqual(data[j], y)
    self.assertEqual(len(data), len(lines) * 2)
    self.assertEqual(data._length, len(lines) * 2)
    self.assertEqual(data[len(data) - 1], lines[-1])
    self.assertIsInstance(data._dataset, lineflow.core.ConcatDataset)
    self.assertIsInstance(data.map(lambda x: x)._dataset, TextDataset)
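# Sketch summarizing the two modes exercised above in one place; it reuses the
# same temp-file fixture and is added here for illustration, not part of the
# original suite:
def test_mode_summary_sketch(self):
    pairs = TextDataset([self.fp.name, self.fp.name], mode='zip')
    chained = TextDataset([self.fp.name, self.fp.name], mode='concat')
    # 'zip' keeps the input length and pairs aligned lines into tuples;
    # 'concat' chains the files end to end, doubling the length here
    self.assertEqual(len(pairs), len(self.lines))
    self.assertEqual(len(chained), 2 * len(self.lines))
    self.assertTupleEqual(pairs[0], (self.lines[0], self.lines[0]))
    self.assertEqual(chained[len(self.lines)], self.lines[0])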
import json
import math

import spacy
from torch.utils.data import DataLoader
from chainer.iterators import MultiprocessIterator
from lineflow import TextDataset


class BatchedDataset:  # class name and __init__ are assumed; the excerpt only shows the two methods below
    def __init__(self, dataset, batch_size):
        self._dataset = dataset
        self._batch_size = batch_size

    def __len__(self):
        # ceil must wrap the whole division so the final partial batch is counted
        return int(math.ceil(len(self._dataset) / float(self._batch_size)))

    def __getitem__(self, index):
        # clamp the upper bound so the final batch does not index past the end
        start = index * self._batch_size
        end = min((index + 1) * self._batch_size, len(self._dataset))
        return [self._dataset[i] for i in range(start, end)]


if __name__ == '__main__':
    nlp = spacy.load('en_core_web_sm',
                     disable=['vectors', 'textcat', 'tagger', 'ner'])

    ds = TextDataset('dev-v1.1.jsonl').map(json.loads) \
        .map(lambda x: [token.text for token in nlp(x['question'])
                        if not token.is_space])

    # PyTorch
    print('PyTorch')
    loader = DataLoader(ds, batch_size=3, num_workers=4, shuffle=True)
    it = iter(loader)
    print(next(it))
    del it

    # Chainer
    print('Chainer')
    it = MultiprocessIterator(ds, batch_size=3, n_processes=4, shuffle=True)
    print(next(it))
    it.finalize()
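    # Sketch: exercising the batch wrapper above directly (BatchedDataset is
    # the name assumed at its definition); indexing returns ready-made lists,
    # which is handy outside framework-specific loaders
    batches = BatchedDataset(ds, batch_size=3)
    print(len(batches))   # batch count, including the final partial batch
    print(batches[0])     # the first three tokenized questions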
        # (tail of process_nq_json, which flat_map applies to every decoded
        # line below; question_txt, correct_long_answers, class_split, and the
        # candidate loop variables i / cand_txt are defined earlier in the
        # function, outside this excerpt)
        soup = BeautifulSoup(cand_txt, 'html.parser')  # explicit parser avoids the guessed-parser warning
        cand_txt = soup.get_text()
        class_split[i not in correct_long_answers] += 1
        cand_obj = {"question": question_txt,
                    "context": cand_txt,
                    "is_impossible": i not in correct_long_answers}
        print(cand_obj, flush=True)
        examples.append(cand_obj)
    return examples


s_time = time.time()
if os.path.exists(nq_save_path):
    print("File at save path already exists!")
else:
    print("Processing dataset!", flush=True)
    nq = TextDataset(nq_path).map(json.loads).flat_map(process_nq_json)

    print("Testing one batch from dataloader...", flush=True)
    loader = DataLoader(nq, batch_size=128, num_workers=4, shuffle=True)
    it = iter(loader)
    print(next(it))

    print("Saving dataset...", flush=True)
    nq.save(nq_save_path)
    print("Dataset saved to:", nq_save_path, flush=True)

e_time = time.time()
print("Script finished in", e_time - s_time)
print("is_impossible class split:", class_split)
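# Sketch of reading the cache back on a later run. This assumes Dataset.save
# writes a plain pickle of the fully evaluated examples (not verified here;
# check Dataset.save in the lineflow version you use):
import pickle

def load_cached_nq(path):
    with open(path, 'rb') as f:
        # a list of {"question", "context", "is_impossible"} dicts
        return pickle.load(f)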
def test_raises_value_error_with_invalid_mode(self):
    with self.assertRaises(ValueError):
        TextDataset([self.fp.name, self.fp.name], mode='invalid_mode')