def _find_unique_entry(names, base_name):
    """Return the single archive entry in *names* whose base name is *base_name*.

    Raises:
        ValueError: if zero or several entries match.
    """
    matches = [name for name in names if Path(name).name == base_name]
    if len(matches) != 1:
        raise ValueError(
            "expected exactly one entry named {!r} in model archive, found {}: {!r}".format(
                base_name, len(matches), matches
            )
        )
    return matches[0]


def load_modelzipfile(url, param_name="hparams.json", model_name="model_last",
                      root="~/.models"):
    """Download (or reuse a cached copy of) a zipped model and read its contents.

    The archive is obtained via ``dataset.cached_download(url)`` while the
    dataset root is temporarily set to *root* (with ``~`` expanded). The
    previous root is restored afterwards, even if the download raises.

    Args:
        url: URL of the zip archive.
        param_name: Base file name of the parameter entry inside the archive.
        model_name: Base file name of the model entry inside the archive.
        root: Directory used as the dataset root during the download.

    Returns:
        ``(param, model)``. *param* is the parameter entry decoded as UTF-8
        text, read through ``TextIOWrapper`` so newlines are translated.
        *model* is a ``BytesIO`` over the raw bytes of the model entry.

    Raises:
        ValueError: if the archive does not contain exactly one file whose
            base name is *param_name*, or exactly one named *model_name*.
    """
    previous_root = dataset.get_dataset_root()
    dataset.set_dataset_root(Path(root).expanduser())
    try:
        path = dataset.cached_download(url)
    finally:
        # The dataset root is global state; always put it back.
        dataset.set_dataset_root(previous_root)

    with ZipFile(path) as archive:
        # Directory entries end with "/"; only regular files are candidates.
        file_names = [name for name in archive.namelist() if not name.endswith("/")]

        param_entry = _find_unique_entry(file_names, param_name)
        with TextIOWrapper(archive.open(param_entry, "r"), encoding="utf-8") as param_file:
            param = param_file.read()

        model_entry = _find_unique_entry(file_names, model_name)
        model = BytesIO(archive.read(model_entry))

    return param, model
def test_cached_download(self):
    """cached_download on an http URL retrieves it once and caches the payload."""
    def fake_retrieve(url, destination):
        # Stand-in for the network fetch: write a fixed payload to the target.
        with open(destination, 'w') as handle:
            handle.write('test')

    with mock.patch('six.moves.urllib.request.urlretrieve') as retrieve:
        retrieve.side_effect = fake_retrieve
        cache_path = dataset.cached_download('http://example.com')

    self.assertEqual(retrieve.call_count, 1)
    call_args, call_kwargs = retrieve.call_args
    self.assertEqual(call_kwargs, {})
    self.assertEqual(len(call_args), 2)
    # Only the URL is compared: the second argument is a temporary path
    # that the original test notes is removed afterwards.
    self.assertEqual(call_args[0], 'http://example.com')

    self.assertTrue(os.path.exists(cache_path))
    with open(cache_path) as cached:
        contents = cached.read()
    self.assertEqual(contents, 'test')
def test_cached_download(self):
    """cached_download on an https URL retrieves it once and caches the payload."""
    def fake_retrieve(url, destination):
        # Stand-in for the network fetch: write a fixed payload to the target.
        with open(destination, 'w') as handle:
            handle.write('test')

    with mock.patch('six.moves.urllib.request.urlretrieve') as retrieve:
        retrieve.side_effect = fake_retrieve
        cache_path = dataset.cached_download('https://example.com')

    self.assertEqual(retrieve.call_count, 1)
    call_args, call_kwargs = retrieve.call_args
    self.assertEqual(call_kwargs, {})
    self.assertEqual(len(call_args), 2)
    # Only the URL is compared: the second argument is a temporary path
    # that the original test notes is removed afterwards.
    self.assertEqual(call_args[0], 'https://example.com')

    self.assertTrue(os.path.exists(cache_path))
    with open(cache_path) as cached:
        contents = cached.read()
    self.assertEqual(contents, 'test')
def test_file_exists(self):
    """cached_download raises OSError when a file occupies the cache dir name."""
    # Create an empty regular file named like the cache directory.
    blocker = os.path.join(self.temp_dir, '_dl_cache')
    open(blocker, 'w').close()

    with self.assertRaises(OSError):
        dataset.cached_download('http://example.com')
def test_fail_to_make_dir(self):
    """cached_download lets an OSError from os.makedirs propagate."""
    # Every os.makedirs call raises while the patch is active.
    with mock.patch('os.makedirs', side_effect=OSError()), \
            self.assertRaises(OSError):
        dataset.cached_download('http://example.com')
def test_file_exists(self):
    """cached_download raises OSError when a file occupies the cache dir name."""
    # Create an empty regular file named like the cache directory.
    blocker = os.path.join(self.temp_dir, '_dl_cache')
    open(blocker, 'w').close()

    with self.assertRaises(OSError):
        dataset.cached_download('http://example.com')
def test_fail_to_make_dir(self):
    """cached_download lets an OSError from os.makedirs propagate."""
    # Every os.makedirs call raises while the patch is active.
    with mock.patch('os.makedirs', side_effect=OSError()), \
            self.assertRaises(OSError):
        dataset.cached_download('http://example.com')