Example #1
0
    def test_text(self):
        """A single-file TextDataset iterates, indexes, caches its length,
        and composes with map()."""
        fp = self.fp
        lines = self.lines

        data = TextDataset(fp.name)
        # Length is computed lazily, so it starts out unset.
        self.assertEqual(data._length, None)

        # Iteration yields the file's lines in order.
        for got, expected in zip(data, lines):
            self.assertEqual(got, expected)

        # Random access returns the same lines.
        for index, expected in enumerate(lines):
            self.assertEqual(data[index], expected)

        self.assertEqual(len(data), len(lines))
        self.assertEqual(data._length, len(lines))
        # A second len() call should hit the cached value.
        self.assertEqual(len(data), len(lines))

        self.assertIsInstance(data._dataset, arrayfiles.TextFile)

        data = data.map(str.split)

        # map() applies the function to each yielded line.
        for got, expected in zip(data, lines):
            self.assertEqual(got, expected.split())

        self.assertIsInstance(data, lineflow.core.MapDataset)
        self.assertIsInstance(data._dataset, TextDataset)
Example #2
0
    def test_zips_multiple_files(self):
        """mode='zip' pairs corresponding lines of the given files."""
        fp = self.fp
        lines = self.lines

        # Zipping the same file with itself pairs every line with itself.
        data = TextDataset([fp.name, fp.name], mode='zip')
        for got, expected in zip(data, lines):
            self.assertTupleEqual(got, (expected, expected))
        for index, expected in enumerate(lines):
            self.assertTupleEqual(data[index], (expected, expected))
        self.assertEqual(len(data), len(lines))
        self.assertEqual(data._length, len(lines))
        self.assertIsInstance(data._dataset, lineflow.core.ZipDataset)
        self.assertIsInstance(data.map(lambda x: x)._dataset, TextDataset)
Example #3
0
    def test_concats_multiple_files(self):
        """mode='concat' chains the given files end to end."""
        fp = self.fp
        lines = self.lines

        # Concatenating the same file twice doubles the line sequence.
        data = TextDataset([fp.name, fp.name], mode='concat')
        doubled = lines + lines
        for got, expected in zip(data, doubled):
            self.assertEqual(got, expected)
        for index, expected in enumerate(doubled):
            self.assertEqual(data[index], expected)
        self.assertEqual(len(data), len(doubled))
        self.assertEqual(data._length, len(doubled))

        # The final element of the concatenation is the last original line.
        self.assertEqual(data[len(data) - 1], lines[-1])
        self.assertIsInstance(data._dataset, lineflow.core.ConcatDataset)
        self.assertIsInstance(data.map(lambda x: x)._dataset, TextDataset)
Example #4
0
    def __len__(self):
        """Return the number of batches, counting a final partial batch.

        Bug fix: the original wrote ``int(math.ceil(len(self._dataset)) /
        float(self._batch_size))`` — ``ceil`` wrapped only the (integer)
        length, a no-op, and ``int()`` then truncated the division, so the
        trailing partial batch was silently dropped. ``ceil`` must wrap
        the division itself.
        """
        # math.ceil returns an int in Python 3, so no int()/float() casts
        # are needed.
        return math.ceil(len(self._dataset) / self._batch_size)

    def __getitem__(self, index):
        """Return the ``index``-th batch as a list of dataset items.

        Bug fix: the original iterated ``range(index * bs, (index + 1) * bs)``
        without clamping, so requesting the last, possibly partial, batch
        indexed past the end of the dataset (presumably raising IndexError —
        behavior depends on the underlying dataset). The slice end is now
        clamped to ``len(self._dataset)``.
        """
        start = index * self._batch_size
        stop = min(start + self._batch_size, len(self._dataset))
        return [self._dataset[i] for i in range(start, stop)]


if __name__ == '__main__':
    # Load spaCy for tokenization only; components that are not needed
    # are disabled to speed up loading.
    nlp = spacy.load('en_core_web_sm',
                     disable=['vectors', 'textcat', 'tagger', 'ner'])
    # Parse each JSON line and tokenize its question, dropping whitespace
    # tokens.
    ds = (TextDataset('dev-v1.1.jsonl')
          .map(json.loads)
          .map(lambda x: [token.text for token in nlp(x['question'])
                          if not token.is_space]))

    # PyTorch
    print('PyTorch')
    loader = DataLoader(ds, batch_size=3, num_workers=4, shuffle=True)
    torch_iter = iter(loader)
    print(next(torch_iter))
    del torch_iter

    # Chainer
    print('Chainer')
    chainer_iter = MultiprocessIterator(ds, batch_size=3, n_processes=4,
                                        shuffle=True)
    print(next(chainer_iter))
    chainer_iter.finalize()
Example #5
0
        soup = BeautifulSoup(cand_txt)
        cand_txt = soup.get_text()

        class_split[i not in correct_long_answers] += 1
        cand_obj = {"question": question_txt, "context": cand_txt, "is_impossible": i not in correct_long_answers}
        print(cand_obj, flush=True)

        examples.append(cand_obj)

    return examples


# Build and save the Natural Questions dataset unless it is already on disk.
s_time = time.time()
if not os.path.exists(nq_save_path):
    print("Processing dataset!", flush=True)
    nq = TextDataset(nq_path).map(json.loads).flat_map(process_nq_json)

    # Smoke-test the pipeline by pulling a single batch before saving.
    print("Testing one batch from dataloader...", flush=True)
    loader = DataLoader(nq, batch_size=128, num_workers=4, shuffle=True)
    print(next(iter(loader)))

    print("Saving dataset...", flush=True)
    nq.save(nq_save_path)
    print("Dataset saved to:", nq_save_path, flush=True)
else:
    print("File at save path already exists!")
e_time = time.time()
print("Script finished in", e_time - s_time)
print("is_impossible class split:", class_split)
Example #6
0
 def test_raises_value_error_with_invalid_mode(self):
     # An unrecognized mode must be rejected at construction time.
     self.assertRaises(ValueError, TextDataset,
                       [self.fp.name, self.fp.name], mode='invalid_mode')