def test_basic_usage(self) -> None:
    """Load a model from the hub and check something was cached on disk.

    Asserts that ``hub.load`` returns a ``ModelProto`` and that, afterwards,
    at least one ``.onnx`` file exists anywhere beneath ``hub.get_dir()``.
    """
    loaded = hub.load(self.name, self.repo)
    self.assertIsInstance(loaded, ModelProto)

    # "**" with recursive=True matches files at any depth under the cache dir.
    onnx_pattern = join(hub.get_dir(), "**", "*.onnx")
    matches = glob.glob(onnx_pattern, recursive=True)
    self.assertGreaterEqual(len(matches), 1)
def test_force_reload(self) -> None:
    """Load a model with ``force_reload=True`` and check the on-disk cache.

    Asserts that the call still returns a ``ModelProto`` and that at least
    one ``.onnx`` file exists anywhere beneath ``hub.get_dir()`` afterwards.
    """
    reloaded = hub.load(self.name, self.repo, force_reload=True)
    self.assertIsInstance(reloaded, ModelProto)

    # Recursive search so nested cache layouts are also counted.
    onnx_pattern = join(hub.get_dir(), "**", "*.onnx")
    matches = glob.glob(onnx_pattern, recursive=True)
    self.assertGreaterEqual(len(matches), 1)
def test_custom_cache(self) -> None:
    """Load a model into a custom hub cache directory and check it lands there.

    Temporarily points the hub cache at a ``"custom"`` subdirectory of the
    current cache dir (via ``hub.set_dir``), loads the model, and asserts
    that at least one ``.onnx`` file exists beneath that custom directory.
    The previous cache directory is always restored, even if loading or an
    assertion fails.
    """
    old_cache = hub.get_dir()
    new_cache = join(old_cache, "custom")
    hub.set_dir(new_cache)
    # Restore in `finally`: the original only reset the directory on the
    # success path, so a failure here leaked the custom cache dir into every
    # test that ran afterwards.
    try:
        model = hub.load(self.name, self.repo)
        self.assertIsInstance(model, ModelProto)

        # Search only the custom dir, recursively, for the downloaded model.
        cached_files = glob.glob(join(new_cache, "**", "*.onnx"), recursive=True)
        self.assertGreaterEqual(len(cached_files), 1)
    finally:
        hub.set_dir(old_cache)
def test_manifest_not_found(self) -> None:
    """Expect ``AssertionError`` when loading from repo ``onnx/models:unknown``.

    ``silent=True`` is passed through to ``hub.load`` unchanged. The
    ``:unknown`` ref is presumably a nonexistent branch/tag whose manifest
    cannot be fetched (TODO confirm against ``hub.load``).
    """
    with self.assertRaises(AssertionError):
        hub.load(self.name, "onnx/models:unknown", silent=True)
def test_opset_error(self) -> None:
    """Expect ``AssertionError`` when requesting the invalid opset ``-1``."""
    with self.assertRaises(AssertionError):
        hub.load(self.name, self.repo, opset=-1)
def test_download_with_opset(self) -> None:
    """Load the model with an explicit ``opset=12`` and expect a ``ModelProto``."""
    requested_opset = 12
    self.assertIsInstance(
        hub.load(self.name, self.repo, opset=requested_opset), ModelProto
    )