def test_raises_type_error_with_version(self, version):
    """Asserts `_get_path_for_version` raises TypeError for `version`."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    self.assertRaises(TypeError, manager._get_path_for_version, version)
def test_returns_version_with_path(self, path, expected_version):
    """Asserts `_get_version_for_path(path)` equals `expected_version`."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir='/tmp', prefix='a_')
    self.assertEqual(manager._get_version_for_path(path), expected_version)
async def test_raises_type_error_with_version(self, version):
    """Asserts `save` raises TypeError when given the parameterized `version`."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    with self.assertRaises(TypeError):
        await manager.save('state', version)
async def test_writes_program_state(self, program_state, expected_value):
    """Asserts `save` hands `program_state` to `file_utils.write_saved_model`.

    The writer is mocked; the single recorded call must carry the expected
    value, the destination `<root_dir>/a_1`, and no keyword arguments.
    """

    def _normalize(value: Any) -> Any:
        # Datasets cannot be compared directly, so compare their elements.
        if isinstance(value, tf.data.Dataset):
            return list(value)
        return value

    root_dir = self.create_tempdir()
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_', keep_total=0)

    with mock.patch.object(file_utils, 'write_saved_model') as mock_write:
        await manager.save(program_state, 1)

    mock_write.assert_called_once()
    _, (written_value, written_path), written_kwargs = mock_write.mock_calls[0]

    written_value = tree.map_structure(_normalize, written_value)
    expected_value = tree.map_structure(_normalize, expected_value)
    self.assertAllEqual(written_value, expected_value)
    self.assertEqual(written_path, os.path.join(root_dir, 'a_1'))
    self.assertEqual(written_kwargs, {})
def test_returns_path_with_root_dir_and_prefix(
        self, root_dir, prefix, version, expected_path):
    """Asserts the path built for `version` under `root_dir` and `prefix`."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix=prefix)
    self.assertEqual(manager._get_path_for_version(version), expected_path)
def test_returns_none_with_path(self, path):
    """Asserts `_get_version_for_path` yields None for the parameterized `path`."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    self.assertIsNone(manager._get_version_for_path(path))
def test_creates_new_dir_with_root_dir_path_like(self):
    """Asserts the constructor creates a `root_dir` that does not exist yet."""
    missing_dir = self.create_tempdir()
    # Delete the fresh tempdir so the constructor has to create it again.
    shutil.rmtree(missing_dir)
    self.assertFalse(os.path.exists(missing_dir))

    file_program_state_manager.FileProgramStateManager(root_dir=missing_dir)

    self.assertTrue(os.path.exists(missing_dir))
async def test_returns_none_with_no_files(self):
    """Asserts `get_versions` is None when `root_dir` is empty."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    self.assertIsNone(await manager.get_versions())
async def test_raises_version_not_found_error_with_no_saved_program_state(
        self):
    """Asserts `load` raises StateNotFoundError when nothing was ever saved."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    not_found = program_state_manager.ProgramStateManagerStateNotFoundError
    with self.assertRaises(not_found):
        await manager.load(0, None)
async def test_removes_saved_program_state_last(self):
    """Asserts `_remove` deletes the only saved version, emptying `root_dir`."""
    root_dir = self.create_tempdir()
    only_version_dir = os.path.join(root_dir, 'a_1')
    os.mkdir(only_version_dir)
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_')

    await manager._remove(1)

    self.assertCountEqual(os.listdir(root_dir), [])
async def test_noops_with_unknown_version(self):
    """Asserts `_remove` leaves existing state alone for an unsaved version."""
    root_dir = self.create_tempdir()
    existing_dir = os.path.join(root_dir, 'a_1')
    os.mkdir(existing_dir)
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_')

    await manager._remove(10)

    self.assertCountEqual(os.listdir(root_dir), ['a_1'])
async def test_returns_none_if_root_dir_does_not_exist(self):
    """Asserts `get_versions` returns None when `root_dir` is missing.

    The constructor creates a missing `root_dir` (see
    `test_creates_new_dir_with_root_dir_path_like`). The directory is therefore
    removed *after* constructing the manager. Otherwise it would exist again
    and this test would only duplicate the empty-directory case.
    """
    root_dir = self.create_tempdir()
    program_state_mngr = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_')
    shutil.rmtree(root_dir)
    # Guard against a silent regression back to testing an existing directory.
    self.assertFalse(os.path.exists(root_dir))

    versions = await program_state_mngr.get_versions()

    self.assertIsNone(versions)
async def test_removes_saved_program_state_with_keep_first_false(self):
    """With keep_total=3, keep_first=False only the 3 highest versions remain."""
    root_dir = self.create_tempdir()
    for name in [f'a_{i}' for i in range(10)]:
        os.mkdir(os.path.join(root_dir, name))
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_', keep_total=3, keep_first=False)

    await manager._remove_old_program_state()

    self.assertCountEqual(os.listdir(root_dir), ['a_7', 'a_8', 'a_9'])
async def test_returns_none_with_no_saved_program_state(self):
    """Asserts `get_versions` is None when only stray prefixed files exist.

    `tempfile.mkstemp` creates regular files named `a_<random suffix>` in
    `root_dir`. They share the prefix, but the test expects the manager not to
    report them as saved program state.
    """
    root_dir = self.create_tempdir()
    for _ in range(10):
        # mkstemp returns an open OS-level file descriptor. Close it
        # immediately so the test does not leak descriptors.
        fd, _path = tempfile.mkstemp(prefix=os.path.join(root_dir, 'a_'))
        os.close(fd)
    program_state_mngr = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_')

    versions = await program_state_mngr.get_versions()

    self.assertIsNone(versions)
async def test_raises_structure_error(self):
    """Asserts `load` raises StructureError for a mismatched `structure`."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    await manager.save('state_1', 1)

    mismatched_structure = []
    structure_error = program_state_manager.ProgramStateManagerStructureError
    with self.assertRaises(structure_error):
        await manager.load(1, mismatched_structure)
async def test_raises_version_not_found_error_with_unknown_version(self):
    """Asserts `load` raises StateNotFoundError for a version never saved."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    await manager.save('state_1', 1)

    not_found = program_state_manager.ProgramStateManagerStateNotFoundError
    with self.assertRaises(not_found):
        await manager.load(10, 'state')
async def test_returns_versions_with_saved_program_state(self, count):
    """Asserts `get_versions` lists versions 0..count-1 after creating them."""
    root_dir = self.create_tempdir()
    expected_versions = list(range(count))
    for v in expected_versions:
        os.mkdir(os.path.join(root_dir, f'a_{v}'))
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_', keep_total=0)

    self.assertEqual(await manager.get_versions(), expected_versions)
async def test_does_not_remove_saved_program_state_with_keep_total_0(self):
    """Asserts `_remove_old_program_state` deletes nothing when keep_total=0."""
    root_dir = self.create_tempdir()
    all_dirs = [f'a_{i}' for i in range(10)]
    for name in all_dirs:
        os.mkdir(os.path.join(root_dir, name))
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_', keep_total=0)

    await manager._remove_old_program_state()

    self.assertCountEqual(os.listdir(root_dir), all_dirs)
async def test_removes_saved_program_state(self):
    """Asserts `save` invokes `_remove_old_program_state` exactly once."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    with mock.patch.object(
            manager, '_remove_old_program_state') as mock_cleanup:
        await manager.save('state_1', 1)
        mock_cleanup.assert_called_once()
async def test_raises_version_already_exists_error_with_existing_version(
        self):
    """Asserts saving the same version twice raises StateAlreadyExistsError."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    await manager.save('state_1', 1)

    already_exists = (
        program_state_manager.ProgramStateManagerStateAlreadyExistsError)
    with self.assertRaises(already_exists):
        await manager.save('state_1', 1)
async def test_removes_saved_program_state_with_version(self, version):
    """Asserts `_remove(version)` deletes only the directory for `version`.

    Directories `a_0`, `a_1` and `a_2` are created up front. After removal,
    the two directories not matching `version` must remain.
    """
    root_dir = self.create_tempdir()
    # Use a distinct loop variable. Reusing `version` here would overwrite the
    # parameterized value, and every case would then remove `a_2`.
    for i in range(3):
        os.mkdir(os.path.join(root_dir, f'a_{i}'))
    program_state_mngr = file_program_state_manager.FileProgramStateManager(
        root_dir=root_dir, prefix='a_')

    await program_state_mngr._remove(version)

    expected_dirs = ['a_0', 'a_1', 'a_2']
    expected_dirs.remove(f'a_{version}')
    self.assertCountEqual(os.listdir(root_dir), expected_dirs)
async def test_returns_saved_program_state_with_version(self, version):
    """Asserts `load(version, ...)` returns the state saved under `version`."""
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    for saved_version in range(3):
        await manager.save(f'state_{saved_version}', saved_version)

    loaded = await manager.load(version, 'state')

    self.assertEqual(loaded, f'state_{version}')
async def test_returns_saved_program_state(
        self, program_state, expected_program_state):
    """Asserts a saved `program_state` loads back as `expected_program_state`.

    The saved value itself is used as the `structure` argument to `load`.
    """
    manager = file_program_state_manager.FileProgramStateManager(
        root_dir=self.create_tempdir(), prefix='a_')
    await manager.save(program_state, 1)

    loaded = await manager.load(1, program_state)

    # Top-level datasets are compared element-wise.
    expected = expected_program_state
    if isinstance(loaded, tf.data.Dataset):
        loaded = list(loaded)
    if isinstance(expected, tf.data.Dataset):
        expected = list(expected)
    self.assertAllEqual(loaded, expected)
def test_raises_type_error_with_keep_total(self, keep_total):
    """Asserts the constructor rejects `keep_total` with TypeError."""
    tmp_dir = self.create_tempdir()
    self.assertRaises(
        TypeError,
        file_program_state_manager.FileProgramStateManager,
        root_dir=tmp_dir,
        keep_total=keep_total)
def test_raises_type_error_with_prefix(self, prefix):
    """Asserts the constructor rejects `prefix` with TypeError."""
    tmp_dir = self.create_tempdir()
    self.assertRaises(
        TypeError,
        file_program_state_manager.FileProgramStateManager,
        root_dir=tmp_dir,
        prefix=prefix)
def test_raises_value_error_with_root_dir_empty(self):
    """Asserts an empty-string `root_dir` is rejected with ValueError."""
    self.assertRaises(
        ValueError,
        file_program_state_manager.FileProgramStateManager,
        root_dir='')
def test_raises_type_error_with_root_dir(self, root_dir):
    """Asserts the constructor rejects the parameterized `root_dir` with TypeError."""
    self.assertRaises(
        TypeError,
        file_program_state_manager.FileProgramStateManager,
        root_dir=root_dir)