def test_new_file(self):
    def create(path):
        with open(path, 'w') as f:
            f.write('test')

    creator = mock.Mock()
    creator.side_effect = create
    loader = mock.Mock()

    dir_path = tempfile.mkdtemp()
    # The file never exists beforehand because the directory is newly created.
    path = os.path.join(dir_path, 'cache')
    try:
        download.cache_or_load_file(path, creator, loader)

        self.assertEqual(creator.call_count, 1)
        self.assertFalse(loader.called)
        self.assertTrue(os.path.exists(path))
        with open(path) as f:
            self.assertEqual(f.read(), 'test')
    finally:
        shutil.rmtree(dir_path)
def test_file_exists(self):
    creator = mock.Mock()
    loader = mock.Mock()

    # The file does not exist beforehand because the directory is newly created.
    path = os.path.join(self.dir_path, 'cache')
    with self.assertRaises(RuntimeError):
        download.cache_or_load_file(path, creator, loader)
def test_cache_exists(self):
    creator = mock.Mock()
    loader = mock.Mock()

    file_desc, file_name = tempfile.mkstemp()
    try:
        download.cache_or_load_file(file_name, creator, loader)
    finally:
        os.close(file_desc)
        os.remove(file_name)

    self.assertFalse(creator.called)
    loader.assert_called_once_with(file_name)
def get_text_classification_dataset(key) -> Dict[str, Union[List, easyfile.CsvFile]]:
    url = urls[key]
    root = download.get_cache_directory(
        os.path.join('datasets', 'text_classification', key))

    def list_creator(path):
        dataset = {}
        archive_path = gdown.cached_download(url)

        # Raise the CSV field size limit as far as the platform allows.
        maxsize = sys.maxsize
        while True:
            try:
                csv.field_size_limit(maxsize)
                break
            except OverflowError:
                maxsize = int(maxsize / 10)
        csv.field_size_limit(maxsize)

        with tarfile.open(archive_path, 'r') as archive:
            for split in ('train', 'test'):
                filename = f'{key}_csv/{split}.csv'
                print(f'Processing {filename}...')
                reader = csv.reader(
                    io.TextIOWrapper(archive.extractfile(filename),
                                     encoding='utf-8'))
                dataset[split] = list(reader)

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def easyfile_creator(path):
        archive_path = gdown.cached_download(url)
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {root}...')
            archive.extractall(root)

        dataset = {}
        for split in ('train', 'test'):
            filename = f'{key}_csv/{split}.csv'
            dataset[split] = easyfile.CsvFile(os.path.join(root, filename))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    assert key in urls

    if key in ('ag_news', 'dbpedia'):
        creator = list_creator
    else:
        creator = easyfile_creator

    pkl_path = os.path.join(root, f'{key}.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
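# A minimal usage sketch, assuming 'ag_news' is one of the keys in `urls` and the
# archive is still downloadable. Keys handled by `list_creator` return plain row
# lists (for AG News, each row is class index, title, description); the other
# keys return an `easyfile.CsvFile`, which is indexed lazily instead of being
# loaded into memory.
#
#     dataset = get_text_classification_dataset('ag_news')
#     label, title, description = dataset['train'][0]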
def get_scitldr(mode: str = "a") -> Dict[str, Any]:
    url = {
        "a": "https://raw.githubusercontent.com/allenai/scitldr/master/SciTLDR-Data/SciTLDR-A/{}.jsonl",
        "aic": "https://raw.githubusercontent.com/allenai/scitldr/master/SciTLDR-Data/SciTLDR-AIC/{}.jsonl",
        "full": "https://raw.githubusercontent.com/allenai/scitldr/master/SciTLDR-Data/SciTLDR-FullText/{}.jsonl",
    }[mode]
    root = download.get_cache_directory(os.path.join("datasets", "scitldr"))

    def creator(path):
        dataset = {}
        for split in ("train", "test", "dev"):
            d_path = gdown.cached_download(url.format(split))
            dataset[split] = []
            with open(d_path, "r") as _f:
                for line in _f.readlines():
                    dataset[split].append(json.loads(line))

        with open(path, "wb") as _f:
            pickle.dump(dataset, _f)
        return dataset

    def loader(path):
        with open(path, "rb") as _f:
            return pickle.load(_f)

    pkl_path = os.path.join(root, f"scitldr_{mode}.pkl")
    return download.cache_or_load_file(pkl_path, creator, loader)
def get_snli() -> Dict[str, List[str]]:
    url = 'https://nlp.stanford.edu/projects/snli/snli_1.0.zip'
    root = download.get_cache_directory(os.path.join('datasets', 'snli'))

    def creator(path):
        archive_path = gdown.cached_download(url)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            dataset = {}
            path2key = {
                'snli_1.0/snli_1.0_train.jsonl': 'train',
                'snli_1.0/snli_1.0_dev.jsonl': 'dev',
                'snli_1.0/snli_1.0_test.jsonl': 'test',
            }
            for p, key in path2key.items():
                print(f'Extracting {p}...')
                with archive.open(p) as f:
                    lines = [json.loads(line.decode('utf-8')) for line in f]
                dataset[key] = lines

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'snli.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
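# A minimal usage sketch, assuming the SNLI zip is still hosted at the URL above.
# Each split is a list of parsed jsonl records, so the standard SNLI fields
# ('sentence1', 'sentence2', 'gold_label') are available as dictionary keys.
#
#     snli = get_snli()
#     pair = snli['train'][0]
#     print(pair['sentence1'], pair['sentence2'], pair['gold_label'])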
def get_cnn_dailymail() -> Dict[str, Tuple[arrayfiles.TextFile]]:
    url = 'https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz'
    root = download.get_cache_directory(
        os.path.join('datasets', 'cnn_dailymail'))

    def creator(path):
        archive_path = gdown.cached_download(url)
        target_path = os.path.join(root, 'raw')
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {target_path}')
            archive.extractall(target_path)

        dataset = {}
        for split in ('train', 'dev', 'test'):
            src_path = f'{split if split != "dev" else "val"}.txt.src'
            tgt_path = f'{split if split != "dev" else "val"}.txt.tgt.tagged'
            dataset[split] = (
                arrayfiles.TextFile(os.path.join(target_path, src_path)),
                arrayfiles.TextFile(os.path.join(target_path, tgt_path)))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'cnndm.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
def get_small_parallel_enja() -> Dict[str, Tuple[List[str]]]:
    en_url = 'https://raw.githubusercontent.com/odashi/small_parallel_enja/master/{}.en'
    ja_url = 'https://raw.githubusercontent.com/odashi/small_parallel_enja/master/{}.ja'
    root = download.get_cache_directory(
        os.path.join('datasets', 'small_parallel_enja'))

    def creator(path):
        dataset = {}
        for split in ('train', 'dev', 'test'):
            en_path = gdown.cached_download(en_url.format(split))
            ja_path = gdown.cached_download(ja_url.format(split))
            with io.open(en_path, 'rt') as en, io.open(ja_path, 'rt') as ja:
                dataset[split] = [(x.rstrip(os.linesep), y.rstrip(os.linesep))
                                  for x, y in zip(en, ja)]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'enja.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
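# A minimal usage sketch: each split holds line-aligned (English, Japanese)
# sentence pairs as tuples of stripped lines, assuming both raw files download
# cleanly from the GitHub raw URLs above.
#
#     enja = get_small_parallel_enja()
#     en, ja = enja['train'][0]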
def get_imdb() -> Dict[str, List[str]]:
    url = 'https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
    root = download.get_cache_directory(os.path.join('datasets'))

    def creator(path):
        archive_path = gdown.cached_download(url)
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {root}...')
            archive.extractall(root)

        extracted_path = os.path.join(root, 'aclImdb')
        dataset = {}
        for split in ('train', 'test'):
            pos_path = os.path.join(extracted_path, split, 'pos')
            neg_path = os.path.join(extracted_path, split, 'neg')
            dataset[split] = \
                [x.path for x in os.scandir(pos_path)
                 if x.is_file() and x.name.endswith('.txt')] + \
                [x.path for x in os.scandir(neg_path)
                 if x.is_file() and x.name.endswith('.txt')]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'aclImdb', 'imdb.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
def get_msr_paraphrase() -> Dict[str, List[Dict[str, str]]]:
    url = 'https://raw.githubusercontent.com/wasiahmad/paraphrase_identification/master/dataset/msr-paraphrase-corpus/msr_paraphrase_{}.txt'  # NOQA
    root = download.get_cache_directory(os.path.join('datasets', 'msr_paraphrase'))

    def creator(path):
        dataset = {}
        fieldnames = ('quality', 'id1', 'id2', 'string1', 'string2')
        for split in ('train', 'test'):
            data_path = gdown.cached_download(url.format(split))
            with io.open(data_path, 'r', encoding='utf-8') as f:
                f.readline()  # skip header
                reader = csv.DictReader(f, delimiter='\t', fieldnames=fieldnames)
                dataset[split] = [dict(row) for row in reader]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'msr_paraphrase.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
def get_wikitext(name: str) -> Dict[str, Union[arrayfiles.TextFile, List]]:
    url = f'https://s3.amazonaws.com/research.metamind.io/wikitext/{name}-v1.zip'
    root = download.get_cache_directory(os.path.join('datasets', 'wikitext'))

    def list_creator(path):
        archive_path = gdown.cached_download(url)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            dataset = {}
            path2key = {
                f'{name}/wiki.train.tokens': 'train',
                f'{name}/wiki.valid.tokens': 'dev',
                f'{name}/wiki.test.tokens': 'test',
            }
            for p, key in path2key.items():
                print(f'Extracting {p}...')
                with archive.open(p) as f:
                    lines = [line.decode('utf-8').rstrip(os.linesep) for line in f]
                dataset[key] = lines

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def easyfile_creator(path):
        archive_path = gdown.cached_download(url)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            print(f'Extracting to {root}...')
            archive.extractall(root)

        dataset = {}
        for split in ('train', 'dev', 'test'):
            filename = 'wiki.{}.tokens'.format(split if split != 'dev' else 'valid')
            dataset[split] = arrayfiles.TextFile(os.path.join(root, name, filename))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    assert name in ('wikitext-2', 'wikitext-103')

    if name == 'wikitext-2':
        creator = list_creator
    elif name == 'wikitext-103':
        creator = easyfile_creator

    pkl_path = os.path.join(root, f'{name.replace("-", "")}.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
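# A minimal usage sketch. For 'wikitext-2' each split is an in-memory list of
# lines; for 'wikitext-103' it is an `arrayfiles.TextFile` that reads lines
# lazily from the extracted file. Both support indexing the same way.
#
#     wikitext = get_wikitext('wikitext-2')
#     first_line = wikitext['train'][0]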
def get_commonsenseqa() -> Dict[str, List[str]]:
    train_url = "https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl"
    dev_url = "https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl"
    test_url = "https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl"
    root = download.get_cache_directory(
        os.path.join("datasets", "commonsenseqa"))

    def creator(path):
        train_path = gdown.cached_download(train_url)
        dev_path = gdown.cached_download(dev_url)
        test_path = gdown.cached_download(test_url)

        dataset = {}
        for split in ("train", "dev", "test"):
            data_path = {
                "train": train_path,
                "dev": dev_path,
                "test": test_path,
            }[split]
            with io.open(data_path, "rt", encoding="utf-8") as f:
                data = [json.loads(line) for line in f.readlines()]
            temp = []
            for x in data:
                answer_key = x["answerKey"] if split != "test" else ""
                options = {
                    choice["label"]: choice["text"]
                    for choice in x["question"]["choices"]
                }
                stem = x["question"]["stem"]
                temp.append({
                    "id": x["id"],
                    "answer_key": answer_key,
                    "options": options,
                    "stem": stem,
                })
            dataset[split] = temp

        with io.open(path, "wb") as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, "rb") as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, "commonsenseqa.pkl")
    return download.cache_or_load_file(pkl_path, creator, loader)
def get_penn_treebank() -> Dict[str, List[str]]:
    url = 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.{}.txt'
    root = download.get_cache_directory(os.path.join('datasets', 'ptb'))

    def creator(path):
        dataset = {}
        for split in ('train', 'dev', 'test'):
            data_path = download.cached_download(
                url.format(split if split != 'dev' else 'valid'))
            with io.open(data_path, 'rt') as f:
                dataset[split] = [line.rstrip(os.linesep) for line in f]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'ptb.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
def get_squad(version: int) -> Dict[str, List]:
    version_str = 'v1.1' if version == 1 else 'v2.0'
    train_url = f'https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-{version_str}.json'
    dev_url = f'https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-{version_str}.json'
    root = download.get_cache_directory(os.path.join('datasets', 'squad'))

    def creator(path):
        train_path = gdown.cached_download(train_url)
        dev_path = gdown.cached_download(dev_url)

        dataset = {}
        for split in ('train', 'dev'):
            data_path = train_path if split == 'train' else dev_path
            with io.open(data_path, 'rt', encoding='utf-8') as f:
                data = json.load(f)['data']
            temp = []
            for x in data:
                title = x['title']
                for paragraph in x['paragraphs']:
                    context = paragraph['context']
                    for qa in paragraph['qas']:
                        qa['title'] = title
                        qa['context'] = context
                        temp.append(qa)
            dataset[split] = temp

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, f'squad.{version_str}.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
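# A minimal usage sketch, assuming the standard SQuAD JSON schema ('id',
# 'question', 'answers') plus the 'title' and 'context' fields attached by the
# creator above:
#
#     squad = get_squad(version=1)
#     qa = squad['dev'][0]
#     print(qa['title'], qa['question'], qa['answers'][0]['text'])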
def get_wmt14() -> Dict[str, Tuple[arrayfiles.TextFile]]:
    url = 'https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8'
    root = download.get_cache_directory(os.path.join('datasets', 'wmt14'))

    def creator(path):
        archive_path = gdown.cached_download(url)
        target_path = os.path.join(root, 'raw')
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {target_path}')
            archive.extractall(target_path)

        split2filename = {
            'train': 'train.tok.clean.bpe.32000',
            'dev': 'newstest2013.tok.bpe.32000',
            'test': 'newstest2014.tok.bpe.32000',
        }
        dataset = {}
        for split, filename in split2filename.items():
            src_path = f'{filename}.en'
            tgt_path = f'{filename}.de'
            dataset[split] = (
                arrayfiles.TextFile(os.path.join(target_path, src_path)),
                arrayfiles.TextFile(os.path.join(target_path, tgt_path)))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'wmt14.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
def get_conll2000() -> Dict[str, List[str]]:
    url = 'https://www.clips.uantwerpen.be/conll2000/chunking/{}.txt.gz'
    root = download.get_cache_directory(os.path.join('datasets', 'conll2000'))

    def creator(path):
        dataset = {}
        for split in ('train', 'test'):
            data_path = gdown.cached_download(url.format(split))
            with gzip.open(data_path) as f:
                data = f.read().decode('utf-8').split('\n\n')
            dataset[split] = data

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'conll2000.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
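# A minimal usage sketch: each split is a list of sentence blocks, where a block
# is the raw CoNLL-2000 text for one sentence (one 'token POS chunk-tag' line
# per token), assuming the gzip files download and decode correctly.
#
#     conll = get_conll2000()
#     first_sentence_lines = conll['train'][0].splitlines()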