예제 #1
0
def load_modelzipfile(url,
                      param_name="hparams.json",
                      model_name="model_last",
                      root="~/.models"):
    """Download a zipped model archive and extract its parameters and weights.

    The archive at ``url`` is fetched with ``dataset.cached_download`` while
    the global dataset root is temporarily pointed at ``root``; the previous
    root is restored afterwards, even if the download fails.  Inside the
    archive, directory entries are ignored and files are matched by their
    final path component, so they may live at any depth.

    Args:
        url: URL of the zip archive.
        param_name: Base name of the UTF-8 text file holding the parameters.
        model_name: Base name of the model file, read as raw bytes.
        root: Directory used as the download root; ``~`` is expanded.

    Returns:
        A ``(param, model)`` tuple: ``param`` is the decoded text of the
        parameter file, ``model`` is a ``BytesIO`` over the model file bytes.

    Raises:
        AssertionError: if the archive does not contain exactly one file
            named ``param_name`` or exactly one file named ``model_name``.
    """
    previous_root = dataset.get_dataset_root()
    dataset.set_dataset_root(Path(root).expanduser())
    try:
        path = dataset.cached_download(url)
    finally:
        # Restore the global root unconditionally so a failed download does
        # not leave later dataset calls redirected to ``root``.
        dataset.set_dataset_root(previous_root)

    with ZipFile(path) as archive:
        # Zip directory entries end with "/"; only real files are candidates.
        file_names = [name for name in archive.namelist()
                      if not name.endswith("/")]

        param_entry = _find_unique_entry(file_names, param_name)
        with TextIOWrapper(archive.open(param_entry, "r"),
                           encoding="utf-8") as param_file:
            param = param_file.read()

        model_entry = _find_unique_entry(file_names, model_name)
        model = BytesIO(archive.read(model_entry))

    return param, model


def _find_unique_entry(names, base_name):
    """Return the single entry of ``names`` whose last path component is ``base_name``.

    Raises AssertionError (the same type the original ``assert`` checks
    produced, but raised explicitly so the check survives ``python -O``)
    when there is not exactly one match.
    """
    matches = [name for name in names if Path(name).name == base_name]
    if len(matches) != 1:
        raise AssertionError(
            "expected exactly one archive entry named {!r}, found {}: {}"
            .format(base_name, len(matches), matches))
    return matches[0]
예제 #2
0
    def test_cached_download(self):
        """cached_download fetches an http URL once and serves it from cache."""
        def fake_urlretrieve(url, path):
            # Stand-in for the network fetch: write known content to the
            # destination path supplied by cached_download.
            with open(path, 'w') as out:
                out.write('test')

        with mock.patch('six.moves.urllib.request.urlretrieve') as retrieve:
            retrieve.side_effect = fake_urlretrieve
            cache_path = dataset.cached_download('http://example.com')

        # Expect a single positional-only call of the form (url, destination).
        self.assertEqual(retrieve.call_count, 1)
        call_args, call_kwargs = retrieve.call_args
        self.assertEqual(call_kwargs, {})
        self.assertEqual(len(call_args), 2)
        # call_args[1] is a temporary destination chosen by cached_download;
        # only the URL is compared here.
        self.assertEqual(call_args[0], 'http://example.com')

        # The fetched content must be readable at the returned cache path.
        self.assertTrue(os.path.exists(cache_path))
        with open(cache_path) as cached:
            self.assertEqual(cached.read(), 'test')
예제 #3
0
    def test_cached_download(self):
        """An https URL is fetched exactly once and its payload cached."""
        target_url = 'https://example.com'
        payload = 'test'

        def write_payload(url, path):
            # Replaces the real download by writing the payload locally.
            with open(path, 'w') as handle:
                handle.write(payload)

        with mock.patch('six.moves.urllib.request.urlretrieve') as fetch:
            fetch.side_effect = write_payload
            cache_path = dataset.cached_download(target_url)

        self.assertEqual(fetch.call_count, 1)
        positional, keyword = fetch.call_args
        self.assertEqual(keyword, {})
        self.assertEqual(len(positional), 2)
        # positional[1] is the temporary download destination; it is not
        # inspected, only the URL is.
        self.assertEqual(positional[0], target_url)

        self.assertTrue(os.path.exists(cache_path))
        with open(cache_path) as handle:
            cached_text = handle.read()
        self.assertEqual(cached_text, payload)
예제 #4
0
 def test_file_exists(self):
     """cached_download raises OSError when a file occupies the cache dir name."""
     # Block the cache directory's name with an empty regular file so the
     # directory cannot be used there.
     blocker = os.path.join(self.temp_dir, '_dl_cache')
     open(blocker, 'w').close()
     with self.assertRaises(OSError):
         dataset.cached_download('http://example.com')
예제 #5
0
 def test_fail_to_make_dir(self):
     """An OSError from os.makedirs surfaces from cached_download."""
     # Every os.makedirs call fails while cached_download runs.
     with mock.patch('os.makedirs', side_effect=OSError()):
         with self.assertRaises(OSError):
             dataset.cached_download('http://example.com')
예제 #6
0
 def test_file_exists(self):
     """A regular file squatting on the cache directory name yields OSError."""
     squatter = os.path.join(self.temp_dir, '_dl_cache')
     # Create it empty; nothing is written, only its existence matters.
     with open(squatter, 'w'):
         pass
     with self.assertRaises(OSError):
         dataset.cached_download('http://example.com')
예제 #7
0
 def test_fail_to_make_dir(self):
     """cached_download must not swallow a failing os.makedirs."""
     makedirs_patch = mock.patch('os.makedirs')
     # Patch is entered first, then assertRaises, matching the nested form.
     with makedirs_patch as makedirs, self.assertRaises(OSError):
         makedirs.side_effect = OSError()
         dataset.cached_download('http://example.com')