def test_dunder_init(self):
    """A newly built TextFile records its arguments and leaves derived state unset."""
    text = easyfile.TextFile(self.fp.name)
    # The constructor stores the path as given; encoding defaults to 'utf-8'.
    self.assertEqual(text._path, self.fp.name)
    self.assertEqual(text._encoding, 'utf-8')
    # No index/handle state should exist yet right after construction.
    self.assertFalse(text._ready)
    for attr_name in ('_length', '_offsets', '_mm'):
        self.assertIsNone(getattr(text, attr_name))
def __init__(self,
             paths: Union[str, List[str]],
             encoding: str = 'utf-8',
             mode: str = 'zip') -> None:
    """Wrap one or more text files and pass the result to the parent constructor.

    Args:
        paths: A single file path, or a list (or tuple) of file paths.
        encoding: Encoding forwarded to every ``easyfile.TextFile``.
        mode: How several files are combined: ``'zip'`` wraps them in a
            ``ZipDataset``, ``'concat'`` wraps them in a ``ConcatDataset``.
            Ignored when ``paths`` is a single string.

    Raises:
        ValueError: If several paths are given and ``mode`` is neither
            ``'zip'`` nor ``'concat'``.
        TypeError: If ``paths`` is neither a string nor a list/tuple.
    """
    if isinstance(paths, str):
        dataset = easyfile.TextFile(paths, encoding)
    elif isinstance(paths, (list, tuple)):
        # Validate the mode before opening any files.
        if mode == 'zip':
            combiner = ZipDataset
        elif mode == 'concat':
            combiner = ConcatDataset
        else:
            raise ValueError(
                f"only 'zip' and 'concat' are valid for 'mode', but '{mode}' is given."
            )
        dataset = combiner(*[easyfile.TextFile(p, encoding) for p in paths])
    else:
        # Without this branch `dataset` would be unbound below and the
        # failure would surface as an unhelpful UnboundLocalError.
        raise TypeError(
            f"'paths' must be a str or a list of str, but "
            f"{type(paths).__name__} is given."
        )
    super().__init__(dataset)
def creator(path):
    """Download and extract the archive, then pickle per-split file pairs to ``path``.

    Uses ``url`` and ``root`` from the enclosing scope. The tarball is
    extracted under ``<root>/raw``.

    Args:
        path: Destination file for the pickled dataset dict.

    Returns:
        A dict mapping ``'train'``, ``'dev'`` and ``'test'`` to a tuple
        ``(source TextFile, target TextFile)``.
    """
    archive_path = download.cached_download(url)
    target_path = os.path.join(root, 'raw')
    with tarfile.open(archive_path, 'r') as archive:
        print(f'Extracting to {target_path}')
        if hasattr(tarfile, 'data_filter'):
            # The archive comes from the network: refuse absolute paths,
            # '..' components and links that would escape target_path.
            archive.extractall(target_path, filter='data')
        else:
            # Older Pythons without extraction filters.
            archive.extractall(target_path)
    dataset = {}
    for split in ('train', 'dev', 'test'):
        # Files for the 'dev' split are named with 'val' in the archive.
        file_split = 'val' if split == 'dev' else split
        src_file = os.path.join(target_path, f'{file_split}.txt.src')
        tgt_file = os.path.join(target_path, f'{file_split}.txt.tgt.tagged')
        dataset[split] = (easyfile.TextFile(src_file),
                          easyfile.TextFile(tgt_file))
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
def easyfile_creator(path):
    """Download and unzip the corpus under ``root``, then pickle per-split TextFiles.

    Uses ``url``, ``root`` and ``name`` from the enclosing scope.

    Args:
        path: Destination file for the pickled dataset dict.

    Returns:
        A dict mapping ``'train'``, ``'dev'`` and ``'test'`` to an
        ``easyfile.TextFile`` over ``<root>/<name>/wiki.<suffix>.tokens``.
    """
    archive_path = download.cached_download(url)
    with zipfile.ZipFile(archive_path, 'r') as archive:
        print(f'Extracting to {root}...')
        archive.extractall(root)
    # The on-disk file for the 'dev' split uses the 'valid' suffix.
    suffix_by_split = {'train': 'train', 'dev': 'valid', 'test': 'test'}
    dataset = {
        split: easyfile.TextFile(os.path.join(root, name, f'wiki.{suffix}.tokens'))
        for split, suffix in suffix_by_split.items()
    }
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
def test_dunder_setstate(self):
    """__getstate__ drops ``_mm``; __setstate__ puts the attribute back."""
    tf = easyfile.TextFile(self.fp.name)
    snapshot = tf.__getstate__()
    self.assertNotIn('_mm', snapshot)

    tf.__setstate__(snapshot)
    self.assertIn('_mm', tf.__dict__)
def test_dunder_len(self):
    """len() matches the fixture's expected line count (``self.length``)."""
    expected = self.length
    self.assertEqual(len(easyfile.TextFile(self.fp.name)), expected)
def test_raises_index_error_with_invalid_index(self):
    """Indices one step past either end raise IndexError."""
    text = easyfile.TextFile(self.fp.name)
    for out_of_range in (self.length, -self.length - 1):
        with self.assertRaises(IndexError):
            text[out_of_range]
def test_slices_items(self):
    """A slice covering every index yields the same items as the TextFile."""
    text = easyfile.TextFile(self.fp.name)
    full_slice = text[:self.length]
    self.assertSequenceEqual(full_slice, text)
def test_iterates_each_line(self):
    """Iteration yields every line in order, and exactly ``self.length`` of them."""
    text = easyfile.TextFile(self.fp.name)
    seen = 0
    for i, x in enumerate(text):
        self.assertEqual(x, f'line #{i}')
        seen += 1
    # Without this check the test passes vacuously if iteration yields
    # nothing, or silently accepts iteration that stops early.
    self.assertEqual(seen, self.length)
def test_supports_random_access(self):
    """Every line is reachable through both its positive and its negative index."""
    text = easyfile.TextFile(self.fp.name)
    for pos in range(self.length):
        expected = f'line #{pos}'
        neg = pos - self.length
        self.assertEqual(text[pos], expected)
        self.assertEqual(text[neg], expected)