Example #1
def get_imdb() -> Dict[str, List[str]]:

    url = 'https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
    root = download.get_cache_directory(os.path.join('datasets'))

    def creator(path):
        archive_path = gdown.cached_download(url)
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {root}...')
            archive.extractall(root)

        extracted_path = os.path.join(root, 'aclImdb')

        dataset = {}
        for split in ('train', 'test'):
            pos_path = os.path.join(extracted_path, split, 'pos')
            neg_path = os.path.join(extracted_path, split, 'neg')
            dataset[split] = [x.path for x in os.scandir(pos_path)
                              if x.is_file() and x.name.endswith('.txt')] + \
                             [x.path for x in os.scandir(neg_path)
                              if x.is_file() and x.name.endswith('.txt')]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'aclImdb', 'imdb.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
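A usage sketch (not part of the source): the returned dictionary maps each split to a flat list of review file paths, and the sentiment label can be recovered from the parent directory name. The import path below is an assumption.

# Hypothetical usage; the import path is an assumption.
import os

from lineflow.datasets.imdb import get_imdb

raw = get_imdb()
path = raw['train'][0]                           # e.g. .../aclImdb/train/pos/123_8.txt
label = os.path.basename(os.path.dirname(path))  # 'pos' or 'neg'
with open(path, encoding='utf-8') as f:
    print(label, f.read()[:80])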
Example #2
def get_text_classification_dataset(key) -> Dict[str, Union[List, easyfile.CsvFile]]:

    url = urls[key]
    root = download.get_cache_directory(os.path.join('datasets', 'text_classification', key))

    def list_creator(path):
        dataset = {}
        archive_path = gdown.cached_download(url)

        maxsize = sys.maxsize
        while True:
            try:
                csv.field_size_limit(maxsize)
                break
            except OverflowError:
                maxsize = int(maxsize / 10)
        csv.field_size_limit(maxsize)

        with tarfile.open(archive_path, 'r') as archive:
            for split in ('train', 'test'):
                filename = f'{key}_csv/{split}.csv'
                print(f'Processing {filename}...')
                reader = csv.reader(
                    io.TextIOWrapper(archive.extractfile(filename), encoding='utf-8'))
                dataset[split] = list(reader)

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def easyfile_creator(path):
        dataset = {}
        archive_path = gdown.cached_download(url)

        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {root}...')
            archive.extractall(root)

        dataset = {}
        for split in ('train', 'test'):
            filename = f'{key}_csv/{split}.csv'
            dataset[split] = easyfile.CsvFile(os.path.join(root, filename))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    assert key in urls

    if key in ('ag_news', 'dbpedia'):
        creator = list_creator
    else:
        creator = easyfile_creator

    pkl_path = os.path.join(root, f'{key}.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
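A minimal usage sketch for the list-backed branch, assuming 'ag_news' is one of the configured keys (it appears in the dispatch above) and that the CSV follows the usual AG News layout of class id, title, and description; the import path is likewise an assumption.

# Hypothetical usage; the import path and column layout are assumptions.
from lineflow.datasets.text_classification import get_text_classification_dataset

data = get_text_classification_dataset('ag_news')  # handled by list_creator above
label, title, description = data['train'][0]       # each row is a list of strings
print(label, title)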
Example #3
def get_msr_paraphrase() -> Dict[str, List[Dict[str, str]]]:

    url = 'https://raw.githubusercontent.com/wasiahmad/paraphrase_identification/master/dataset/msr-paraphrase-corpus/msr_paraphrase_{}.txt'  # NOQA
    root = download.get_cache_directory(os.path.join('datasets', 'msr_paraphrase'))

    def creator(path):
        dataset = {}
        fieldnames = ('quality', 'id1', 'id2', 'string1', 'string2')
        for split in ('train', 'test'):
            data_path = gdown.cached_download(url.format(split))
            with io.open(data_path, 'r', encoding='utf-8') as f:
                f.readline()  # skip header
                reader = csv.DictReader(f, delimiter='\t', fieldnames=fieldnames)
                dataset[split] = [dict(row) for row in reader]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'msr_paraphrase.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
Example #4
    def setUp(self):
        cache_fp = tempfile.NamedTemporaryFile()
        self.cache_fp = cache_fp

        cached_download_patcher = patch(
            'lineflow.datasets.cnn_dailymail.cached_download')
        cached_download_mock = cached_download_patcher.start()
        cached_download_mock.side_effect = lambda url: cache_fp.name
        self.cached_download_patcher = cached_download_patcher
        self.cached_download_mock = cached_download_mock

        tarfile_patcher = patch('lineflow.datasets.cnn_dailymail.tarfile')
        tarfile_mock = tarfile_patcher.start()
        self.tarfile_patcher = tarfile_patcher
        self.tarfile_mock = tarfile_mock

        exists_patcher = patch('lineflow.datasets.cnn_dailymail.Path.exists')
        exists_mock = exists_patcher.start()
        exists_mock.return_value = True
        self.exists_patcher = exists_patcher
        self.exists_mock = exists_mock

        init_patcher = patch(
            'lineflow.datasets.seq2seq.Seq2SeqDataset.__init__')
        init_mock = init_patcher.start()
        self.init_patcher = init_patcher
        self.init_mock = init_mock

        self.cache_dir = Path(get_cache_directory('cnndm'))
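Only setUp is shown in this excerpt; a plausible tearDown counterpart (an assumption, not taken from the source) would stop every patcher and release the temporary file:

    def tearDown(self):
        # Assumed counterpart to the setUp above (not shown in the excerpt):
        # stop each patcher that was started and close the temp file.
        self.init_patcher.stop()
        self.exists_patcher.stop()
        self.tarfile_patcher.stop()
        self.cached_download_patcher.stop()
        self.cache_fp.close()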
Example #5
File: snli.py Project: tofunlp/lineflow
def get_snli() -> Dict[str, List[str]]:

    url = 'https://nlp.stanford.edu/projects/snli/snli_1.0.zip'
    root = download.get_cache_directory(os.path.join('datasets', 'snli'))

    def creator(path):
        archive_path = gdown.cached_download(url)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            dataset = {}
            path2key = {
                'snli_1.0/snli_1.0_train.jsonl': 'train',
                'snli_1.0/snli_1.0_dev.jsonl': 'dev',
                'snli_1.0/snli_1.0_test.jsonl': 'test',
            }
            for p, key in path2key.items():
                print(f'Extracting {p}...')
                with archive.open(p) as f:
                    lines = [json.loads(line.decode('utf-8')) for line in f]
                dataset[key] = lines

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'snli.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
Example #6
File: scitldr.py Project: tofunlp/lineflow
def get_scitldr(mode: str = "a") -> Dict[str, Any]:

    url = {
        "a": "https://raw.githubusercontent.com/allenai/scitldr/master/SciTLDR-Data/SciTLDR-A/{}.jsonl",
        "aic": "https://raw.githubusercontent.com/allenai/scitldr/master/SciTLDR-Data/SciTLDR-AIC/{}.jsonl",
        "full": "https://raw.githubusercontent.com/allenai/scitldr/master/SciTLDR-Data/SciTLDR-FullText/{}.jsonl",
    }[mode]

    root = download.get_cache_directory(os.path.join("datasets", "scitldr"))

    def creator(path):
        dataset = {}
        for split in ("train", "test", "dev"):
            d_path = gdown.cached_download(url.format(split))
            dataset[split] = []
            with open(d_path, "r") as _f:
                for line in _f.readlines():
                    dataset[split].append(json.loads(line))

        with open(path, "wb") as _f:
            pickle.dump(dataset, _f)
        return dataset

    def loader(path):
        with open(path, "rb") as _f:
            return pickle.load(_f)

    pkl_path = os.path.join(root, f"scitldr_{mode}.pkl")
    return download.cache_or_load_file(pkl_path, creator, loader)
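A usage sketch for the loader above; the record fields ('paper_id', 'source', 'target') follow the upstream SciTLDR JSONL schema and, like the import path, are assumptions here.

# Hypothetical usage; the import path and record fields are assumptions.
from lineflow.datasets.scitldr import get_scitldr

tldr = get_scitldr(mode="a")
record = tldr["train"][0]
abstract = " ".join(record["source"])  # source sentences, joined back into text
print(record["paper_id"])
print(abstract[:120])
print("TLDR:", record["target"][0])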
Example #7
def get_cnn_dailymail() -> Dict[str, Tuple[arrayfiles.TextFile]]:

    url = 'https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz'
    root = download.get_cache_directory(
        os.path.join('datasets', 'cnn_dailymail'))

    def creator(path):
        archive_path = gdown.cached_download(url)
        target_path = os.path.join(root, 'raw')
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {target_path}')
            archive.extractall(target_path)

        dataset = {}
        for split in ('train', 'dev', 'test'):
            src_path = f'{split if split != "dev" else "val"}.txt.src'
            tgt_path = f'{split if split != "dev" else "val"}.txt.tgt.tagged'
            dataset[split] = (arrayfiles.TextFile(
                os.path.join(target_path, src_path)),
                              arrayfiles.TextFile(
                                  os.path.join(target_path, tgt_path)))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'cnndm.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
Example #8
def get_small_parallel_enja() -> Dict[str, List[Tuple[str, str]]]:

    en_url = 'https://raw.githubusercontent.com/odashi/small_parallel_enja/master/{}.en'
    ja_url = 'https://raw.githubusercontent.com/odashi/small_parallel_enja/master/{}.ja'
    root = download.get_cache_directory(
        os.path.join('datasets', 'small_parallel_enja'))

    def creator(path):
        dataset = {}
        for split in ('train', 'dev', 'test'):
            en_path = gdown.cached_download(en_url.format(split))
            ja_path = gdown.cached_download(ja_url.format(split))
            with io.open(en_path, 'rt') as en, io.open(ja_path, 'rt') as ja:
                dataset[split] = [(x.rstrip(os.linesep), y.rstrip(os.linesep))
                                  for x, y in zip(en, ja)]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'enja.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
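A usage sketch, assuming the import path below; each entry is an (English, Japanese) sentence pair.

# Hypothetical usage; the import path is an assumption.
from lineflow.datasets.small_parallel_enja import get_small_parallel_enja

enja = get_small_parallel_enja()
en, ja = enja['train'][0]
print(en)
print(ja)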
Example #9
def get_wikitext(name: str) -> Dict[str, Union[arrayfiles.TextFile, List]]:

    url = f'https://s3.amazonaws.com/research.metamind.io/wikitext/{name}-v1.zip'
    root = download.get_cache_directory(os.path.join('datasets', 'wikitext'))

    def list_creator(path):
        archive_path = gdown.cached_download(url)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            dataset = {}
            path2key = {
                f'{name}/wiki.train.tokens': 'train',
                f'{name}/wiki.valid.tokens': 'dev',
                f'{name}/wiki.test.tokens': 'test'
            }
            for p, key in path2key.items():
                print(f'Extracting {p}...')
                with archive.open(p) as f:
                    lines = [
                        line.decode('utf-8').rstrip(os.linesep) for line in f
                    ]
                dataset[key] = lines

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def easyfile_creator(path):
        archive_path = gdown.cached_download(url)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            print(f'Extracting to {root}...')
            archive.extractall(root)

        dataset = {}
        for split in ('train', 'dev', 'test'):
            filename = 'wiki.{}.tokens'.format(
                split if split != 'dev' else 'valid')
            dataset[split] = arrayfiles.TextFile(
                os.path.join(root, name, filename))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    assert name == 'wikitext-2' or name == 'wikitext-103'

    if name == 'wikitext-2':
        creator = list_creator
    elif name == 'wikitext-103':
        creator = easyfile_creator

    pkl_path = os.path.join(root, f'{name.replace("-", "")}.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
Example #10
def get_commonsenseqa() -> Dict[str, List[str]]:
    train_url = "https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl"
    dev_url = "https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl"
    test_url = "https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl"
    root = download.get_cache_directory(
        os.path.join("datasets", "commonsenseqa"))

    def creator(path):
        train_path = gdown.cached_download(train_url)
        dev_path = gdown.cached_download(dev_url)
        test_path = gdown.cached_download(test_url)

        dataset = {}
        for split in ("train", "dev", "test"):
            data_path = {
                "train": train_path,
                "dev": dev_path,
                "test": test_path
            }[split]
            with io.open(data_path, "rt", encoding="utf-8") as f:
                data = [json.loads(line) for line in f.readlines()]
            temp = []
            for x in data:
                answer_key = x["answerKey"] if split != "test" else ""
                options = {
                    choice["label"]: choice["text"]
                    for choice in x["question"]["choices"]
                }
                stem = x["question"]["stem"]
                temp.append({
                    "id": x["id"],
                    "answer_key": answer_key,
                    "options": options,
                    "stem": stem
                })
            dataset[split] = temp

        with io.open(path, "wb") as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, "rb") as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, "commonsenseqa.pkl")
    return download.cache_or_load_file(pkl_path, creator, loader)
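A usage sketch showing how one of the flattened records above might be rendered as a multiple-choice prompt; the import path is an assumption.

# Hypothetical usage; the import path is an assumption.
from lineflow.datasets.commonsenseqa import get_commonsenseqa

qa = get_commonsenseqa()
item = qa['train'][0]
print(item['stem'])
for label, text in sorted(item['options'].items()):
    print(f'  ({label}) {text}')
print('answer:', item['answer_key'])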
Example #11
def get_penn_treebank() -> Dict[str, List[str]]:

    url = 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.{}.txt'
    root = download.get_cache_directory(os.path.join('datasets', 'ptb'))

    def creator(path):
        dataset = {}
        for split in ('train', 'dev', 'test'):
            data_path = download.cached_download(
                url.format(split if split != 'dev' else 'valid'))
            with io.open(data_path, 'rt') as f:
                dataset[split] = [line.rstrip(os.linesep) for line in f]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'ptb.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
Example #12
File: squad.py Project: tofunlp/lineflow
def get_squad(version: int) -> Dict[str, List]:
    version_str = 'v1.1' if version == 1 else 'v2.0'

    train_url = f'https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-{version_str}.json'
    dev_url = f'https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-{version_str}.json'
    root = download.get_cache_directory(os.path.join('datasets', 'squad'))

    def creator(path):
        train_path = gdown.cached_download(train_url)
        dev_path = gdown.cached_download(dev_url)

        dataset = {}
        for split in ('train', 'dev'):
            data_path = train_path if split == 'train' else dev_path
            with io.open(data_path, 'rt', encoding='utf-8') as f:
                data = json.load(f)['data']
            temp = []
            for x in data:
                title = x['title']
                for paragraph in x['paragraphs']:
                    context = paragraph['context']
                    for qa in paragraph['qas']:
                        qa['title'] = title
                        qa['context'] = context
                        temp.append(qa)
            dataset[split] = temp

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, f'squad.{version_str}.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
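A usage sketch; each flattened record keeps the original SQuAD QA fields ('question', 'answers', 'id') plus the 'title' and 'context' attached above. The import path is an assumption.

# Hypothetical usage; the import path is an assumption.
from lineflow.datasets.squad import get_squad

squad = get_squad(version=1)
qa = squad['train'][0]
print(qa['title'])
print(qa['question'])
print(qa['answers'][0]['text'] if qa['answers'] else '(unanswerable)')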
Example #13
def get_wmt14() -> Dict[str, Tuple[arrayfiles.TextFile]]:

    url = 'https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8'
    root = download.get_cache_directory(os.path.join('datasets', 'wmt14'))

    def creator(path):
        archive_path = gdown.cached_download(url)
        target_path = os.path.join(root, 'raw')
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {target_path}')
            archive.extractall(target_path)

        split2filename = {
            'train': 'train.tok.clean.bpe.32000',
            'dev': 'newstest2013.tok.bpe.32000',
            'test': 'newstest2014.tok.bpe.32000'
        }
        dataset = {}
        for split, filename in split2filename.items():
            src_path = f'{filename}.en'
            tgt_path = f'{filename}.de'
            dataset[split] = (arrayfiles.TextFile(
                os.path.join(target_path, src_path)),
                              arrayfiles.TextFile(
                                  os.path.join(target_path, tgt_path)))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'wmt14.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
Example #14
def get_conll2000() -> Dict[str, List[str]]:

    url = 'https://www.clips.uantwerpen.be/conll2000/chunking/{}.txt.gz'
    root = download.get_cache_directory(os.path.join('datasets', 'conll2000'))

    def creator(path):
        dataset = {}
        for split in ('train', 'test'):
            data_path = gdown.cached_download(url.format(split))
            with gzip.open(data_path) as f:
                data = f.read().decode('utf-8').split('\n\n')

            dataset[split] = data

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def loader(path):
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    pkl_path = os.path.join(root, 'conll2000.pkl')
    return download.cache_or_load_file(pkl_path, creator, loader)
Example #15
    def __init__(self, split: str = 'train') -> None:
        path = cached_download(CNN_DAILYMAIL_URL)
        tf = tarfile.open(path, 'r')
        cache_dir = Path(get_cache_directory('cnndm'))
        if not all((cache_dir / p).exists() for p in ALL):
            print(f'Extracting from {path}...')
            tf.extractall(cache_dir)

        if split == 'train':
            src_path = cache_dir / TRAIN_SOURCE_NAME
            tgt_path = cache_dir / TRAIN_TARGET_NAME
        elif split == 'dev':
            src_path = cache_dir / VAL_SOURCE_NAME
            tgt_path = cache_dir / VAL_TARGET_NAME
        elif split == 'test':
            src_path = cache_dir / TEST_SOURCE_NAME
            tgt_path = cache_dir / TEST_TARGET_NAME
        else:
            raise ValueError(
                f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given."
            )

        super().__init__(source_file_path=str(src_path),
                         target_file_path=str(tgt_path))
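A usage sketch; the class name below is an assumption, inferred from the module patched in Example #4 (lineflow.datasets.cnn_dailymail) and the Seq2SeqDataset base class.

# Hypothetical usage; the class name and import path are assumptions.
from lineflow.datasets import CnnDailymail

train = CnnDailymail(split='train')
dev = CnnDailymail(split='dev')
# Any other value for split raises ValueError, per the __init__ above.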
Example #16
    @mock.patch('os.makedirs')  # injects `f`; cf. the context-manager form in Example #18
    def test_fails_to_make_directory(self, f):
        f.side_effect = OSError()
        with self.assertRaises(OSError):
            download.get_cache_directory('/lineflow_test_cache', True)
Example #17
    def test_get_cache_directory(self):
        root = download.get_cache_root()
        path = download.get_cache_directory('test', False)
        self.assertEqual(path, os.path.join(root, 'test'))
Example #18
    def test_fails_to_make_directory(self):
        with mock.patch('os.makedirs') as f:
            f.side_effect = OSError()
            with self.assertRaises(OSError):
                download.get_cache_directory('/lineflow_test_cache', True)