def test_set_cache_root(self):
    """Check that ``set_cache_root`` is reflected by ``get_cache_root``."""
    previous_root = download.get_cache_root()
    replacement_root = '/tmp/cache'
    try:
        download.set_cache_root(replacement_root)
        self.assertEqual(download.get_cache_root(), replacement_root)
    finally:
        # Put the original root back even if the assertion fails, so the
        # override does not outlive this test.
        download.set_cache_root(previous_root)
def setUpClass(cls):
    """Redirect the cache root to a fresh temp dir and patch ``sys.maxsize``.

    The previous cache root is remembered on the class as
    ``default_cache_root`` and the started patcher as ``patcher``.
    """
    cls.default_cache_root = download.get_cache_root()

    scratch_dir = tempfile.mkdtemp()
    cls.temp_dir = scratch_dir
    download.set_cache_root(scratch_dir)

    # Replace ``sys.maxsize`` as seen by the text_classification module with
    # the integer value of the largest representable float.
    maxsize_patcher = mock.patch(
        'lineflow.datasets.text_classification.sys.maxsize',
        int(sys.float_info.max))
    cls.patcher = maxsize_patcher
    maxsize_patcher.start()
def __init__(self, split: str = 'train') -> None:
    """Build the dataset for one split of the downloaded archive.

    The archive at ``IMDB_URL`` is fetched through ``cached_download`` and
    unpacked into the cache root unless every path listed in ``ALL``
    already exists there. Each ``*.txt`` file under the split's ``pos`` and
    ``neg`` directories becomes one ``(text, label)`` example.

    Args:
        split: Which split to load, either ``'train'`` or ``'test'``.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """
    archive_path = cached_download(IMDB_URL)
    cache_dir = Path(get_cache_root())

    if not all((cache_dir / p).exists() for p in ALL):
        print(f'Extracting from {archive_path}...')
        # Open the archive only when extraction is needed, and make sure
        # the handle is closed afterwards.
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; consider the ``filter`` argument on Python versions
        # that support it -- confirm the minimum supported version.
        with tarfile.open(archive_path, 'r') as tf:
            tf.extractall(cache_dir)

    if split == 'train':
        split_dir = cache_dir / TRAIN_DIR
    elif split == 'test':
        split_dir = cache_dir / TEST_DIR
    else:
        raise ValueError(
            f"only 'train' and 'test' are valid for 'split', but '{split}' is given."
        )

    files = list(
        chain((split_dir / 'pos').glob('*.txt'),
              (split_dir / 'neg').glob('*.txt')))

    def map_func(x: Path) -> Tuple[str, int]:
        string = x.read_text()
        # Decide the label from the directory that directly contains the
        # file. A substring test on the whole path would mislabel every
        # review whenever the cache root itself contains "pos"
        # (e.g. ".../repos/...").
        label = 0 if x.parent.name == 'pos' else 1
        return (string, label)

    super().__init__(files, map_func)
def setUpClass(cls):
    """Redirect the download cache root to a fresh temporary directory.

    The previous cache root is remembered on the class as
    ``default_cache_root`` and the new directory as ``temp_dir``.
    """
    cls.default_cache_root = download.get_cache_root()

    scratch_dir = tempfile.mkdtemp()
    cls.temp_dir = scratch_dir
    download.set_cache_root(scratch_dir)
def setUp(self):
    """Point the download cache root at a new temporary directory.

    The previous cache root is kept in ``self.default_cache_root`` and the
    new directory in ``self.temp_dir``.
    """
    self.default_cache_root = download.get_cache_root()

    scratch_dir = tempfile.mkdtemp()
    self.temp_dir = scratch_dir
    download.set_cache_root(scratch_dir)
def setUp(self):
    """Point the cache root at a temporary *file* and create a spare dir.

    The previous cache root is kept in ``self.default_cache_root``. The
    open descriptor and the name returned by ``mkstemp`` are stored in
    ``self.temp_file_desc`` and ``self.temp_file_name``.
    """
    self.default_cache_root = download.get_cache_root()

    # The cache root is set to a regular file rather than a directory.
    # NOTE(review): presumably this exercises failure paths -- confirm
    # against the tests that use this fixture.
    file_desc, file_name = tempfile.mkstemp()
    self.temp_file_desc = file_desc
    self.temp_file_name = file_name
    download.set_cache_root(file_name)

    self.dir_path = tempfile.mkdtemp()
def test_get_cache_directory(self):
    """Check that the cache directory for a name is ``<cache root>/<name>``."""
    expected = os.path.join(download.get_cache_root(), 'test')
    # NOTE(review): the second positional argument is presumably a
    # "create directory" flag -- confirm against get_cache_directory.
    actual = download.get_cache_directory('test', False)
    self.assertEqual(actual, expected)