Exemplo n.º 1
0
    def test_basic_usage(self) -> None:
        """Load a model from the hub and check it was cached on disk.

        Asserts that ``hub.load(self.name, self.repo)`` returns a
        ``ModelProto`` and that at least one ``.onnx`` file exists somewhere
        (recursively) under ``hub.get_dir()``.
        """
        model = hub.load(self.name, self.repo)
        self.assertIsInstance(model, ModelProto)

        # Escape the cache directory so glob metacharacters in the path
        # ("[", "*", "?") are matched literally; only the suffix is a pattern.
        # glob.glob already returns a list, so no extra list() is needed.
        pattern = join(glob.escape(hub.get_dir()), "**", "*.onnx")
        cached_files = glob.glob(pattern, recursive=True)
        self.assertGreaterEqual(len(cached_files), 1)
Exemplo n.º 2
0
    def test_force_reload(self) -> None:
        """Load a model with ``force_reload=True`` and check it is cached.

        Asserts that the load still returns a ``ModelProto`` and that at least
        one ``.onnx`` file exists (recursively) under ``hub.get_dir()``.
        NOTE(review): presumably ``force_reload`` bypasses any existing cached
        copy — confirm against the hub implementation.
        """
        model = hub.load(self.name, self.repo, force_reload=True)
        self.assertIsInstance(model, ModelProto)

        # Escape the cache directory so glob metacharacters in the path are
        # matched literally; glob.glob already returns a list.
        pattern = join(glob.escape(hub.get_dir()), "**", "*.onnx")
        cached_files = glob.glob(pattern, recursive=True)
        self.assertGreaterEqual(len(cached_files), 1)
Exemplo n.º 3
0
    def test_custom_cache(self) -> None:
        """Load a model into a custom cache directory and check it lands there.

        Temporarily points the hub cache at a ``custom`` subdirectory of the
        current cache directory, loads the model, and asserts that at least
        one ``.onnx`` file exists (recursively) under that subdirectory.
        The original cache directory is restored in a ``finally`` so the
        setting does not leak into other tests when this one fails.
        """
        old_cache = hub.get_dir()
        new_cache = join(old_cache, "custom")
        hub.set_dir(new_cache)
        try:
            model = hub.load(self.name, self.repo)
            self.assertIsInstance(model, ModelProto)

            # Escape the directory so glob metacharacters in the path are
            # matched literally; glob.glob already returns a list.
            pattern = join(glob.escape(new_cache), "**", "*.onnx")
            cached_files = glob.glob(pattern, recursive=True)
            self.assertGreaterEqual(len(cached_files), 1)
        finally:
            # Always restore the previous cache dir, even on failure/error.
            hub.set_dir(old_cache)