Example #1
0
 def test_set_cache_root(self):
     """Check that set_cache_root changes what get_cache_root reports.

     The root in effect before the test is put back in ``finally`` so the
     change does not leak into other tests, even if the assertion fails.
     """
     previous = download.get_cache_root()
     target = '/tmp/cache'
     try:
         download.set_cache_root(target)
         observed = download.get_cache_root()
         self.assertEqual(observed, target)
     finally:
         download.set_cache_root(previous)
 @classmethod
 def setUpClass(cls):
     """Redirect the download cache to a fresh temp dir for the whole class.

     unittest invokes ``setUpClass`` as ``cls.setUpClass()`` with no
     arguments, so it must be a classmethod; without the decorator the call
     fails with a TypeError about the missing ``cls`` argument.

     The previous cache root is kept in ``cls.default_cache_root``.
     NOTE(review): presumably ``tearDownClass`` (not visible here) restores
     it, removes ``cls.temp_dir`` and stops ``cls.patcher`` — confirm.
     """
     cls.default_cache_root = download.get_cache_root()
     cls.temp_dir = tempfile.mkdtemp()
     download.set_cache_root(cls.temp_dir)
     # Replace ``maxsize`` with a much larger int while the patch is active.
     # The target is resolved through text_classification's ``sys``
     # attribute, i.e. the real ``sys`` module, so this affects
     # ``sys.maxsize`` process-wide, not only inside that module.
     cls.patcher = mock.patch('lineflow.datasets.text_classification.sys.maxsize',
                              int(sys.float_info.max))
     cls.patcher.start()
Example #3
0
    def __init__(self, split: str = 'train') -> None:
        """Load the IMDB movie-review dataset for the given split.

        The archive at ``IMDB_URL`` is fetched via ``cached_download`` and
        extracted into the cache root (``get_cache_root()``) unless every
        entry listed in ``ALL`` already exists there.

        Each example is a ``(text, label)`` tuple. ``label`` is ``0`` for
        reviews in the split's ``pos`` directory and ``1`` for reviews in its
        ``neg`` directory.

        Args:
            split: Either ``'train'`` or ``'test'``.

        Raises:
            ValueError: If ``split`` is not ``'train'`` or ``'test'``.
        """
        # Validate the argument before any (potentially slow) download or
        # extraction work happens.
        if split == 'train':
            split_dir = TRAIN_DIR
        elif split == 'test':
            split_dir = TEST_DIR
        else:
            raise ValueError(
                f"only 'train' and 'test' are valid for 'split', but '{split}' is given."
            )

        archive_path = cached_download(IMDB_URL)
        cache_dir = Path(get_cache_root())
        if not all((cache_dir / p).exists() for p in ALL):
            print(f'Extracting from {archive_path}...')
            # Open the archive only when it needs extracting, and close it
            # deterministically afterwards.
            with tarfile.open(archive_path, 'r') as tf:
                # The archive is downloaded from the network. Where this
                # Python provides tar extraction filters, use 'data' so that
                # absolute paths, '..' components and links escaping the
                # target directory are rejected.
                if hasattr(tarfile, 'data_filter'):
                    tf.extractall(cache_dir, filter='data')
                else:
                    tf.extractall(cache_dir)

        split_root = cache_dir / split_dir
        files = list(
            chain((split_root / 'pos').glob('*.txt'),
                  (split_root / 'neg').glob('*.txt')))

        def map_func(x: Path) -> Tuple[str, int]:
            # Decode explicitly instead of relying on the platform's default
            # encoding (assumes the review files are UTF-8).
            string = x.read_text(encoding='utf-8')
            # Derive the label from the immediate parent directory only.
            # Searching the whole path for 'pos' mislabels every review when
            # the cache root itself contains that substring (e.g. '~/repos').
            label = 0 if x.parent.name == 'pos' else 1
            return (string, label)

        super().__init__(files, map_func)
Example #4
0
 @classmethod
 def setUpClass(cls):
     """Point the download cache at a fresh temporary directory for the class.

     unittest invokes ``setUpClass`` as ``cls.setUpClass()`` with no
     arguments, so it must be a classmethod; without the decorator the call
     fails with a TypeError about the missing ``cls`` argument.

     The previous cache root is saved in ``cls.default_cache_root``.
     NOTE(review): presumably ``tearDownClass`` (not visible here) restores it
     and removes ``cls.temp_dir`` — confirm.
     """
     cls.default_cache_root = download.get_cache_root()
     cls.temp_dir = tempfile.mkdtemp()
     download.set_cache_root(cls.temp_dir)
Example #5
0
 def setUp(self):
     """Redirect the download cache to a fresh temporary directory per test.

     The root in effect beforehand is remembered in ``self.default_cache_root``
     and the new directory in ``self.temp_dir``. NOTE(review): presumably
     ``tearDown`` (not visible here) restores the root and removes the
     directory — confirm.
     """
     previous_root = download.get_cache_root()
     scratch_dir = tempfile.mkdtemp()
     self.default_cache_root = previous_root
     self.temp_dir = scratch_dir
     download.set_cache_root(scratch_dir)
Example #6
0
 def setUp(self):
     """Install a regular file as the download cache root and make a scratch dir.

     ``tempfile.mkstemp`` creates a real file and returns an open OS-level
     descriptor plus its path. Both are kept on the instance, and the path
     becomes the cache root. The prior root is saved in
     ``self.default_cache_root``. NOTE(review): presumably ``tearDown`` (not
     visible here) closes the descriptor, deletes the file and restores the
     root — confirm.
     """
     self.default_cache_root = download.get_cache_root()
     handle, file_path = tempfile.mkstemp()
     self.temp_file_desc = handle
     self.temp_file_name = file_path
     download.set_cache_root(file_path)
     # Separate temporary directory, created after the root has been redirected.
     self.dir_path = tempfile.mkdtemp()
Example #7
0
 def test_get_cache_directory(self):
     """get_cache_directory('test', False) should be '<cache root>/test'.

     NOTE(review): the meaning of the second positional argument is defined
     in ``download`` (presumably whether to create the directory) — confirm.
     """
     expected = os.path.join(download.get_cache_root(), 'test')
     actual = download.get_cache_directory('test', False)
     self.assertEqual(actual, expected)