Example 1
    def creator(path):
        dataset = {}
        for split in ('train', 'dev', 'test'):
            en_path = download.cached_download(en_url.format(split))
            ja_path = download.cached_download(ja_url.format(split))
            with io.open(en_path, 'rt', encoding='utf-8') as en, \
                    io.open(ja_path, 'rt', encoding='utf-8') as ja:
                dataset[split] = [(x.rstrip(os.linesep), y.rstrip(os.linesep))
                                  for x, y in zip(en, ja)]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset
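This creator(path) shape (build the dataset, pickle it to path, return it) is the callback form used by Chainer-style cache helpers. A minimal sketch of how such a creator is typically wired up, assuming a download.cache_or_load_file helper and a pickle path pkl_path (both are assumptions, not shown above):

    def loader(path):
        # On later runs, reload the pickled dict instead of re-downloading.
        with io.open(path, 'rb') as f:
            return pickle.load(f)

    # The first call runs creator() and writes pkl_path; later calls run loader().
    raw = download.cache_or_load_file(pkl_path, creator, loader)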
Example 2
    def __init__(self, split: str = 'train') -> None:
        path = cached_download(IMDB_URL)
        cache_dir = Path(get_cache_root())
        if not all((cache_dir / p).exists() for p in ALL):
            print(f'Extracting from {path}...')
            # Open the archive only when extraction is needed, and close it
            # via the context manager instead of leaking the handle.
            with tarfile.open(path, 'r') as tf:
                tf.extractall(cache_dir)

        if split == 'train':
            pos_dir = f'{cache_dir / TRAIN_DIR}/pos'
            neg_dir = f'{cache_dir / TRAIN_DIR}/neg'
        elif split == 'test':
            pos_dir = f'{cache_dir / TEST_DIR}/pos'
            neg_dir = f'{cache_dir / TEST_DIR}/neg'
        else:
            raise ValueError(
                f"only 'train' and 'test' are valid for 'split', but '{split}' is given."
            )

        paths = list(
            chain(Path(pos_dir).glob('*.txt'),
                  Path(neg_dir).glob('*.txt')))

        def map_func(x: Path) -> Tuple[str, int]:
            string = x.read_text()
            label = 0 if 'pos' in str(x) else 1  # note: positive reviews get label 0
            return (string, label)

        super().__init__(paths, map_func)
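A hypothetical usage of the class above (the class name Imdb and list-style indexing are assumptions, inferred from the super().__init__(paths, map_func) call):

    train = Imdb(split='train')   # downloads and extracts on first use
    text, label = train[0]        # map_func turns each Path into (text, label)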
Example 3
 def test_cache_exists(self):
     with mock.patch('os.path.exists') as f:
         f.return_value = True
         url = 'https://example.com'
         path = download.cached_download(url)
         self.assertEqual(
             path,
             f'{self.temp_dir}/_dl_cache/{hashlib.md5(url.encode("utf-8")).hexdigest()}'
         )
Example 4
    def creator(path):
        dataset = {}
        for split in ('train', 'dev', 'test'):
            data_path = download.cached_download(
                url.format(split if split != 'dev' else 'valid'))
            with io.open(data_path, 'rt', encoding='utf-8') as f:
                dataset[split] = [line.rstrip(os.linesep) for line in f]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset
Example 5
    def easyfile_creator(path):
        archive_path = download.cached_download(url)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            print(f'Extracting to {root}...')
            archive.extractall(root)

        dataset = {}
        for split in ('train', 'dev', 'test'):
            filename = 'wiki.{}.tokens'.format(split if split != 'dev' else 'valid')
            dataset[split] = easyfile.TextFile(os.path.join(root, name, filename))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset
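Unlike creators that materialize every line in memory (compare Example 8), this one pickles easyfile.TextFile handles, which read lines from disk on demand. A usage sketch, assuming easyfile.TextFile supports len() and integer indexing, with pkl_path standing in for the cache path:

    dataset = easyfile_creator(pkl_path)
    print(len(dataset['train']))   # number of lines in wiki.train.tokens
    print(dataset['train'][0])     # a single line, read lazily from disk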
Example 6
    def __init__(self, split: str = 'train', version: int = 1) -> None:
        if version == 1:
            train_url = TRAIN_V1_URL
            dev_url = DEV_V1_URL
        elif version == 2:
            train_url = TRAIN_V2_URL
            dev_url = DEV_V2_URL
        else:
            raise ValueError(
                f"only 1 and 2 are valid for 'version', but {version} is given."
            )

        if split == 'train':
            path = cached_download(train_url)
        elif split == 'dev':
            path = cached_download(dev_url)
        else:
            raise ValueError(
                f"only 'train' and 'dev' are valid for 'split', but '{split}' is given."
            )

        dataset = RandomAccessText(path)

        super().__init__(dataset, json.loads)
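RandomAccessText exposes the file line by line and json.loads parses each line, so the underlying files must be in JSON Lines format (one JSON object per line). A hypothetical usage, assuming the surrounding class is a SQuAD dataset named Squad:

    dev = Squad(split='dev', version=2)
    example = dev[0]   # a dict parsed from one JSON line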
Example 7
    def creator(path):
        dataset = {}
        fieldnames = ('quality', 'id1', 'id2', 'string1', 'string2')
        for split in ('train', 'test'):
            data_path = download.cached_download(url.format(split))
            with io.open(data_path, 'r', encoding='utf-8') as f:
                f.readline()  # skip header
                reader = csv.DictReader(f,
                                        delimiter='\t',
                                        fieldnames=fieldnames)
                dataset[split] = [dict(row) for row in reader]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset
Example 8
    def list_creator(path):
        archive_path = download.cached_download(url)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            dataset = {}
            path2key = {f'{name}/wiki.train.tokens': 'train',
                        f'{name}/wiki.valid.tokens': 'dev',
                        f'{name}/wiki.test.tokens': 'test'}
            for p, key in path2key.items():
                print(f'Extracting {p}...')
                with archive.open(p) as f:
                    lines = [line.decode('utf-8').rstrip(os.linesep) for line in f]
                dataset[key] = lines

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset
Example 9
    def creator(path):
        archive_path = download.cached_download(url)
        target_path = os.path.join(root, 'raw')
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {target_path}...')
            archive.extractall(target_path)

        dataset = {}
        for split in ('train', 'dev', 'test'):
            src_path = f'{split if split != "dev" else "val"}.txt.src'
            tgt_path = f'{split if split != "dev" else "val"}.txt.tgt.tagged'
            dataset[split] = (easyfile.TextFile(
                os.path.join(target_path, src_path)),
                              easyfile.TextFile(
                                  os.path.join(target_path, tgt_path)))

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset
Example 10
    def __init__(self, split: str = 'train') -> None:
        if split == 'train':
            en_path = cached_download(TRAIN_EN_URL)
            ja_path = cached_download(TRAIN_JA_URL)
        elif split == 'dev':
            en_path = cached_download(DEV_EN_URL)
            ja_path = cached_download(DEV_JA_URL)
        elif split == 'test':
            en_path = cached_download(TEST_EN_URL)
            ja_path = cached_download(TEST_JA_URL)
        else:
            raise ValueError(f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given.")

        super().__init__(source_file_path=en_path, target_file_path=ja_path)
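The base class presumably pairs the two files line by line, so item i is the i-th English sentence together with the i-th Japanese sentence. A hypothetical usage (the class name SmallParallelEnJa is an assumption):

    train = SmallParallelEnJa('train')
    en, ja = train[0]   # one English/Japanese sentence pair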
Example 11
    def test_cached_download(self):
        with mock.patch('urllib.request.urlretrieve') as f:
            def urlretrieve(url, path):
                with open(path, 'w') as f:
                    f.write('test')
            f.side_effect = urlretrieve

            cache_path = download.cached_download('https://example.com')

        self.assertEqual(f.call_count, 1)
        args, kwargs = f.call_args
        self.assertEqual(kwargs, {})
        self.assertEqual(len(args), 2)
        # The second argument is a temporary path, and it is removed
        self.assertEqual(args[0], 'https://example.com')

        self.assertTrue(os.path.exists(cache_path))
        with open(cache_path) as f:
            stored_data = f.read()
        self.assertEqual(stored_data, 'test')
Example 12
    def creator(path):
        archive_path = download.cached_download(url)
        with tarfile.open(archive_path, 'r') as archive:
            print(f'Extracting to {root}...')
            archive.extractall(root)

        extracted_path = os.path.join(root, 'aclImdb')

        dataset = {}
        for split in ('train', 'test'):
            pos_path = os.path.join(extracted_path, split, 'pos')
            neg_path = os.path.join(extracted_path, split, 'neg')
            dataset[split] = [x.path for x in os.scandir(pos_path)
                              if x.is_file() and x.name.endswith('.txt')] + \
                             [x.path for x in os.scandir(neg_path)
                              if x.is_file() and x.name.endswith('.txt')]

        with io.open(path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset
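Note that this creator pickles file paths rather than file contents, so the pickle stays small and consumers read each review on demand. A sketch, with pkl_path standing in for the cache path:

    dataset = creator(pkl_path)
    with io.open(dataset['train'][0], 'rt', encoding='utf-8') as f:
        first_review = f.read()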
Example 13
    def __init__(self, split: str = 'train') -> None:
        path = cached_download(CNN_DAILYMAIL_URL)
        cache_dir = Path(get_cache_directory('cnndm'))
        if not all((cache_dir / p).exists() for p in ALL):
            print(f'Extracting from {path}...')
            # Open the archive only when extraction is needed, and close it
            # via the context manager instead of leaking the handle.
            with tarfile.open(path, 'r') as tf:
                tf.extractall(cache_dir)

        if split == 'train':
            src_path = cache_dir / TRAIN_SOURCE_NAME
            tgt_path = cache_dir / TRAIN_TARGET_NAME
        elif split == 'dev':
            src_path = cache_dir / VAL_SOURCE_NAME
            tgt_path = cache_dir / VAL_TARGET_NAME
        elif split == 'test':
            src_path = cache_dir / TEST_SOURCE_NAME
            tgt_path = cache_dir / TEST_TARGET_NAME
        else:
            raise ValueError(
                f"only 'train', 'dev' and 'test' are valid for 'split', but '{split}' is given."
            )

        super().__init__(source_file_path=str(src_path),
                         target_file_path=str(tgt_path))
Example 14
 def test_file_exists(self):
     # Create an empty file with the same name as the cache directory
     with open(os.path.join(self.temp_dir, '_dl_cache'), 'w'):
         pass
     with self.assertRaises(OSError):
         download.cached_download('https://example.com')
Example 15
 def test_fails_to_make_directory(self, f):
     f.side_effect = OSError()
     with self.assertRaises(OSError):
         download.cached_download('https://example.com')
Example 16
 def test_fails_to_make_directory(self):
     with mock.patch('os.makedirs') as f:
         f.side_effect = OSError()
         with self.assertRaises(OSError):
             download.cached_download('https://example.com')
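Taken together, the tests in Examples 3, 11, and 14-16 pin down the observable behavior of cached_download: files are cached under a _dl_cache directory, named by the MD5 hash of the URL; the download goes through a temporary path that never survives; and failure to create the cache directory surfaces as OSError. A minimal sketch consistent with those tests (an assumption, not the library's actual implementation; get_cache_root is borrowed from Example 2):

    import hashlib
    import os
    import shutil
    import tempfile
    import urllib.request

    def cached_download(url):
        # get_cache_root(): assumed helper (see Example 2) returning the cache root.
        cache_root = os.path.join(get_cache_root(), '_dl_cache')
        try:
            os.makedirs(cache_root)
        except OSError:
            # Re-raise if the path exists but is not a directory (Example 14)
            # or the directory genuinely cannot be created (Examples 15-16).
            if not os.path.isdir(cache_root):
                raise
        cache_path = os.path.join(
            cache_root, hashlib.md5(url.encode('utf-8')).hexdigest())
        if os.path.exists(cache_path):
            return cache_path  # cache hit (Example 3)
        fd, temp_path = tempfile.mkstemp()
        os.close(fd)
        try:
            urllib.request.urlretrieve(url, temp_path)
            shutil.move(temp_path, cache_path)
        finally:
            # The temporary download path is always removed (Example 11).
            if os.path.exists(temp_path):
                os.remove(temp_path)
        return cache_path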