def test_set_dataset_root(self):
    """Check that a root installed via set_dataset_root() is returned by get_dataset_root().

    The previously configured root is always reinstated afterwards.
    """
    replacement = '/tmp/dataset_root'
    saved = dataset.get_dataset_root()
    try:
        dataset.set_dataset_root(replacement)
        current = dataset.get_dataset_root()
        self.assertEqual(current, replacement)
    finally:
        # Put the global root back so later tests see the original value.
        dataset.set_dataset_root(saved)
def _find_unique_entry(names, basename):
    """Return the single entry of ``names`` whose final path component equals ``basename``.

    Raises:
        AssertionError: if zero or more than one entry matches.
    """
    matches = [name for name in names if Path(name).name == basename]
    assert len(matches) == 1, (
        "expected exactly one {!r} in archive, found {}".format(basename, len(matches))
    )
    return matches[0]


def load_modelzipfile(url, param_name="hparams.json", model_name="model_last", root="~/.models"):
    """Download (or reuse a cached copy of) a zipped model and extract two files from it.

    The archive at ``url`` is fetched through ``dataset.cached_download``
    while the dataset root is temporarily redirected to ``root``. The
    previous root is restored afterwards, even if the download fails.

    Args:
        url: URL of the zip archive.
        param_name: Base name of the parameter file inside the archive.
        model_name: Base name of the model file inside the archive.
        root: Directory used as the dataset root for the download cache;
            a leading ``~`` is expanded.

    Returns:
        A ``(param, model)`` tuple. ``param`` is the parameter file's
        contents decoded as UTF-8 text. ``model`` is an ``io.BytesIO``
        holding the raw bytes of the model file.

    Raises:
        AssertionError: if the archive does not contain exactly one file
            named ``param_name``, or exactly one file named ``model_name``.
    """
    previous_root = dataset.get_dataset_root()
    dataset.set_dataset_root(Path(root).expanduser())
    try:
        path = dataset.cached_download(url)
    finally:
        # Always restore the global dataset root, so a failed download does
        # not leave the temporary root in place for unrelated code.
        dataset.set_dataset_root(previous_root)

    with ZipFile(path) as archive:
        # Zip directory entries end with "/"; only real files are candidates.
        file_names = [name for name in archive.namelist() if not name.endswith("/")]

        param_entry = _find_unique_entry(file_names, param_name)
        with TextIOWrapper(archive.open(param_entry, "r"), encoding="utf-8") as pf:
            param = pf.read()

        model_entry = _find_unique_entry(file_names, model_name)
        model = BytesIO(archive.read(model_entry))
    return param, model
def setUp(self):
    """Point the dataset root at a freshly created temporary *file*.

    The original root is remembered in ``self.default_dataset_root``. A
    separate temporary directory is also created and stored in
    ``self.dir_path``.
    """
    self.default_dataset_root = dataset.get_dataset_root()
    # NOTE(review): the root is a regular file, not a directory; presumably
    # deliberate to exercise error paths — confirm against the tests.
    # The descriptor from mkstemp stays open here; presumably it is closed
    # in tearDown — confirm.
    fd, file_path = tempfile.mkstemp()
    self.temp_file_desc = fd
    self.temp_file_name = file_path
    dataset.set_dataset_root(file_path)
    self.dir_path = tempfile.mkdtemp()
def setUp(self):
    """Redirect the dataset root to a brand-new temporary directory.

    The root that was active before is saved in ``self.default_dataset_root``.
    The new directory's path is stored in ``self.temp_dir``.
    """
    self.default_dataset_root = dataset.get_dataset_root()
    scratch_dir = tempfile.mkdtemp()
    self.temp_dir = scratch_dir
    dataset.set_dataset_root(scratch_dir)
def test_get_dataset_directory(self):
    """Assert that the directory for id 'test' is ``<dataset root>/test``."""
    base = dataset.get_dataset_root()
    # Second argument presumably disables directory creation — confirm.
    actual = dataset.get_dataset_directory('test', False)
    expected = os.path.join(base, 'test')
    self.assertEqual(actual, expected)