def creator(path):
    dataset = {}
    for split in ('train', 'dev', 'test'):
        en_path = download.cached_download(en_url.format(split))
        ja_path = download.cached_download(ja_url.format(split))
        with io.open(en_path, 'rt') as en, io.open(ja_path, 'rt') as ja:
            dataset[split] = [(x.rstrip(os.linesep), y.rstrip(os.linesep))
                              for x, y in zip(en, ja)]
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
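# A minimal sketch of how a creator/loader pair like the one above is
# usually wired up. `cache_or_load_file` and the pickle path below are
# assumptions, not confirmed by this file: the idea is that `creator`
# runs only on a cache miss and the pickled dict is reused afterwards.
def loader(path):
    with io.open(path, 'rb') as f:
        return pickle.load(f)

# pkl_path = os.path.join(root, 'enja.pkl')  # hypothetical cache file
# dataset = download.cache_or_load_file(pkl_path, creator, loader)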
def __init__(self, split: str = 'train') -> None:
    path = cached_download(IMDB_URL)
    tf = tarfile.open(path, 'r')
    cache_dir = Path(get_cache_root())
    if not all((cache_dir / p).exists() for p in ALL):
        print(f'Extracting from {path}...')
        tf.extractall(cache_dir)

    if split == 'train':
        pos_dir = f'{cache_dir / TRAIN_DIR}/pos'
        neg_dir = f'{cache_dir / TRAIN_DIR}/neg'
    elif split == 'test':
        pos_dir = f'{cache_dir / TEST_DIR}/pos'
        neg_dir = f'{cache_dir / TEST_DIR}/neg'
    else:
        raise ValueError(
            f"only 'train' and 'test' are valid for 'split', but '{split}' is given."
        )
    path = list(
        chain(Path(pos_dir).glob('*.txt'), Path(neg_dir).glob('*.txt')))

    def map_func(x: Path) -> Tuple[str, int]:
        string = x.read_text()
        label = 0 if 'pos' in str(x) else 1
        return (string, label)

    super().__init__(path, map_func)
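# Hedged usage sketch for the dataset class above (the class name `Imdb`
# is an assumption). Note that `path` in __init__ is reassigned from the
# downloaded archive to the list of extracted *.txt paths, and map_func
# turns each path into a (review_text, label) pair on access:
#
#     train = Imdb(split='train')
#     text, label = train[0]  # label is 0 under pos/, 1 under neg/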
def test_cache_exists(self):
    with mock.patch('os.path.exists') as f:
        f.return_value = True
        url = 'https://example.com'
        path = download.cached_download(url)
        self.assertEqual(
            path,
            f'{self.temp_dir}/_dl_cache/{hashlib.md5(url.encode("utf-8")).hexdigest()}'
        )
def creator(path):
    dataset = {}
    for split in ('train', 'dev', 'test'):
        data_path = download.cached_download(
            url.format(split if split != 'dev' else 'valid'))
        with io.open(data_path, 'rt') as f:
            dataset[split] = [line.rstrip(os.linesep) for line in f]
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
def easyfile_creator(path):
    archive_path = download.cached_download(url)
    with zipfile.ZipFile(archive_path, 'r') as archive:
        print(f'Extracting to {root}...')
        archive.extractall(root)

    dataset = {}
    for split in ('train', 'dev', 'test'):
        filename = 'wiki.{}.tokens'.format(split if split != 'dev' else 'valid')
        dataset[split] = easyfile.TextFile(os.path.join(root, name, filename))
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
def __init__(self, split: str = 'train', version: int = 1) -> None:
    if version == 1:
        train_url = TRAIN_V1_URL
        dev_url = DEV_V1_URL
    elif version == 2:
        train_url = TRAIN_V2_URL
        dev_url = DEV_V2_URL
    else:
        raise ValueError(
            f"only 1 and 2 are valid for 'version', but {version} is given."
        )
    if split == 'train':
        path = cached_download(train_url)
    elif split == 'dev':
        path = cached_download(dev_url)
    else:
        raise ValueError(
            f"only 'train' and 'dev' are valid for 'split', but '{split}' is given."
        )
    dataset = RandomAccessText(path)
    super().__init__(dataset, json.loads)
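# Hedged usage sketch (the class name `Squad` is an assumption): the file
# is exposed line by line through RandomAccessText and each line is parsed
# with json.loads, so indexing yields one dict per example.
#
#     dev = Squad(split='dev', version=2)
#     example = dev[0]  # a dict of JSON fields for one example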
def creator(path):
    dataset = {}
    fieldnames = ('quality', 'id1', 'id2', 'string1', 'string2')
    for split in ('train', 'test'):
        data_path = download.cached_download(url.format(split))
        with io.open(data_path, 'r', encoding='utf-8') as f:
            f.readline()  # skip header
            reader = csv.DictReader(f, delimiter='\t', fieldnames=fieldnames)
            dataset[split] = [dict(row) for row in reader]
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
def list_creator(path):
    archive_path = download.cached_download(url)
    with zipfile.ZipFile(archive_path, 'r') as archive:
        dataset = {}
        path2key = {f'{name}/wiki.train.tokens': 'train',
                    f'{name}/wiki.valid.tokens': 'dev',
                    f'{name}/wiki.test.tokens': 'test'}
        for p, key in path2key.items():
            print(f'Extracting {p}...')
            with archive.open(p) as f:
                lines = [line.decode('utf-8').rstrip(os.linesep) for line in f]
            dataset[key] = lines
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
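# Note on the two WikiText creators: `list_creator` above eagerly decodes
# every line into an in-memory list, while `easyfile_creator` stores lazy
# easyfile.TextFile handles and reads lines on access. A rough rule of
# thumb (illustrative only): prefer the list variant for corpora that fit
# comfortably in memory, and the easyfile variant otherwise.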
def creator(path):
    archive_path = download.cached_download(url)
    target_path = os.path.join(root, 'raw')
    with tarfile.open(archive_path, 'r') as archive:
        print(f'Extracting to {target_path}')
        archive.extractall(target_path)

    dataset = {}
    for split in ('train', 'dev', 'test'):
        src_path = f'{split if split != "dev" else "val"}.txt.src'
        tgt_path = f'{split if split != "dev" else "val"}.txt.tgt.tagged'
        dataset[split] = (
            easyfile.TextFile(os.path.join(target_path, src_path)),
            easyfile.TextFile(os.path.join(target_path, tgt_path)))
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
def __init__(self, split: str = 'train') -> None:
    if split == 'train':
        en_path = cached_download(TRAIN_EN_URL)
        ja_path = cached_download(TRAIN_JA_URL)
    elif split == 'dev':
        en_path = cached_download(DEV_EN_URL)
        ja_path = cached_download(DEV_JA_URL)
    elif split == 'test':
        en_path = cached_download(TEST_EN_URL)
        ja_path = cached_download(TEST_JA_URL)
    else:
        raise ValueError(
            f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given."
        )
    super().__init__(source_file_path=en_path, target_file_path=ja_path)
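# Hedged usage sketch (the class name `SmallParallelEnJa` is an
# assumption): items are aligned (English, Japanese) sentence pairs read
# from the two downloaded files.
#
#     train = SmallParallelEnJa(split='train')
#     en_sentence, ja_sentence = train[0]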
def test_cached_download(self):
    with mock.patch('urllib.request.urlretrieve') as f:
        def urlretrieve(url, path):
            with open(path, 'w') as f:
                f.write('test')
        f.side_effect = urlretrieve

        cache_path = download.cached_download('https://example.com')
        self.assertEqual(f.call_count, 1)
        args, kwargs = f.call_args
        self.assertEqual(kwargs, {})
        self.assertEqual(len(args), 2)
        # The second argument is a temporary path, and it is removed
        self.assertEqual(args[0], 'https://example.com')
        self.assertTrue(os.path.exists(cache_path))
        with open(cache_path) as f:
            stored_data = f.read()
        self.assertEqual(stored_data, 'test')
def creator(path):
    archive_path = download.cached_download(url)
    with tarfile.open(archive_path, 'r') as archive:
        print(f'Extracting to {root}...')
        archive.extractall(root)

    extracted_path = os.path.join(root, 'aclImdb')
    dataset = {}
    for split in ('train', 'test'):
        pos_path = os.path.join(extracted_path, split, 'pos')
        neg_path = os.path.join(extracted_path, split, 'neg')
        dataset[split] = [x.path for x in os.scandir(pos_path)
                          if x.is_file() and x.name.endswith('.txt')] + \
                         [x.path for x in os.scandir(neg_path)
                          if x.is_file() and x.name.endswith('.txt')]
    with io.open(path, 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
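# The creator above pickles *file paths*, not review texts, so a consumer
# still reads each file on access. A minimal follow-up sketch (the helper
# name is illustrative):
def read_review(file_path):
    with io.open(file_path, 'rt', encoding='utf-8') as f:
        return f.read()

# texts = [read_review(p) for p in dataset['train']]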
def __init__(self, split: str = 'train') -> None:
    path = cached_download(CNN_DAILYMAIL_URL)
    tf = tarfile.open(path, 'r')
    cache_dir = Path(get_cache_directory('cnndm'))
    if not all((cache_dir / p).exists() for p in ALL):
        print(f'Extracting from {path}...')
        tf.extractall(cache_dir)

    if split == 'train':
        src_path = cache_dir / TRAIN_SOURCE_NAME
        tgt_path = cache_dir / TRAIN_TARGET_NAME
    elif split == 'dev':
        src_path = cache_dir / VAL_SOURCE_NAME
        tgt_path = cache_dir / VAL_TARGET_NAME
    elif split == 'test':
        src_path = cache_dir / TEST_SOURCE_NAME
        tgt_path = cache_dir / TEST_TARGET_NAME
    else:
        raise ValueError(
            f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given."
        )
    super().__init__(source_file_path=str(src_path), target_file_path=str(tgt_path))
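# Hedged usage sketch (the class name `CnnDailymail` is an assumption):
# the source and target files are line-aligned, so each item pairs one
# article with its tagged summary.
#
#     dev = CnnDailymail(split='dev')
#     article, summary = dev[0]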
def test_file_exists(self):
    # Make an empty file which has the same name as the cache directory
    with open(os.path.join(self.temp_dir, '_dl_cache'), 'w'):
        pass
    with self.assertRaises(OSError):
        download.cached_download('https://example.com')
# The patch decorator is implied by the injected mock argument `f`.
@mock.patch('os.makedirs')
def test_fails_to_make_directory(self, f):
    f.side_effect = OSError()
    with self.assertRaises(OSError):
        download.cached_download('https://example.com')
def test_fails_to_make_directory(self):
    with mock.patch('os.makedirs') as f:
        f.side_effect = OSError()
        with self.assertRaises(OSError):
            download.cached_download('https://example.com')