예제 #1
0
    def test_raises_type_error_with_version(self, version):
        """Expects TypeError from `_get_path_for_version(version)`."""
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')

        # Callable form of assertRaises: the call is made by the assertion.
        self.assertRaises(TypeError, manager._get_path_for_version, version)
예제 #2
0
    def test_returns_version_with_path(self, path, expected_version):
        """Expects `_get_version_for_path(path)` to equal `expected_version`."""
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir='/tmp', prefix='a_')

        self.assertEqual(
            manager._get_version_for_path(path),
            expected_version,
        )
예제 #3
0
    async def test_raises_type_error_with_version(self, version):
        """Expects TypeError when awaiting `save('state', version)`."""
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir,
            prefix='a_',
        )

        # Both the call and the await stay inside the context manager so an
        # error raised at either point is captured.
        with self.assertRaises(TypeError):
            await manager.save('state', version)
예제 #4
0
    async def test_writes_program_state(self, program_state, expected_value):
        """Checks the arguments `save` passes to `file_utils.write_saved_model`.

        `write_saved_model` is patched out; the test expects exactly one call
        with positional arguments `(value, path)` and no keyword arguments,
        where `value` matches `expected_value` and `path` is `<root_dir>/a_1`.
        """

        def _as_comparable(item: Any) -> Any:
            # Datasets are materialized to lists before being compared.
            return list(item) if isinstance(item, tf.data.Dataset) else item

        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_', keep_total=0)

        with mock.patch.object(file_utils, 'write_saved_model') as mock_writer:
            await manager.save(program_state, 1)

            mock_writer.assert_called_once()
            # With exactly one recorded call, `call_args` is that call.
            positional, keywords = mock_writer.call_args
            written_value, written_path = positional

            self.assertAllEqual(
                tree.map_structure(_as_comparable, written_value),
                tree.map_structure(_as_comparable, expected_value),
            )
            self.assertEqual(written_path, os.path.join(tmp_dir, 'a_1'))
            self.assertEqual(keywords, {})
예제 #5
0
    def test_returns_path_with_root_dir_and_prefix(self, root_dir, prefix,
                                                   version, expected_path):
        """Expects `_get_path_for_version(version)` to equal `expected_path`."""
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=root_dir,
            prefix=prefix,
        )

        self.assertEqual(
            manager._get_path_for_version(version),
            expected_path,
        )
예제 #6
0
    def test_returns_none_with_path(self, path):
        """Expects `_get_version_for_path(path)` to return None."""
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')

        self.assertIsNone(manager._get_version_for_path(path))
예제 #7
0
    def test_creates_new_dir_with_root_dir_path_like(self):
        """Expects the constructor to create `root_dir` when it is missing."""
        missing_dir = self.create_tempdir()
        # Remove the freshly created temp dir so the path does not exist.
        shutil.rmtree(missing_dir)
        exists_before = os.path.exists(missing_dir)
        self.assertFalse(exists_before)

        file_program_state_manager.FileProgramStateManager(
            root_dir=missing_dir)

        exists_after = os.path.exists(missing_dir)
        self.assertTrue(exists_after)
예제 #8
0
    async def test_returns_none_with_no_files(self):
        """Expects `get_versions` to return None for an empty `root_dir`."""
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir,
            prefix='a_',
        )

        self.assertIsNone(await manager.get_versions())
예제 #9
0
    async def test_raises_version_not_found_error_with_no_saved_program_state(
            self):
        """Expects `load(0, None)` to raise the state-not-found error.

        Nothing is saved before the load is attempted.
        """
        not_found_error = (
            program_state_manager.ProgramStateManagerStateNotFoundError)
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')

        with self.assertRaises(not_found_error):
            await manager.load(0, None)
예제 #10
0
    async def test_removes_saved_program_state_last(self):
        """Expects `_remove(1)` to delete the only directory, `a_1`."""
        tmp_dir = self.create_tempdir()
        only_state_dir = os.path.join(tmp_dir, 'a_1')
        os.mkdir(only_state_dir)
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')

        await manager._remove(1)

        remaining = os.listdir(tmp_dir)
        self.assertCountEqual(remaining, [])
예제 #11
0
    async def test_noops_with_unknown_version(self):
        """Expects `_remove(10)` to leave the existing `a_1` directory alone."""
        tmp_dir = self.create_tempdir()
        existing_dir = os.path.join(tmp_dir, 'a_1')
        os.mkdir(existing_dir)
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')

        # Version 10 has no directory on disk.
        await manager._remove(10)

        remaining = os.listdir(tmp_dir)
        self.assertCountEqual(remaining, ['a_1'])
예제 #12
0
    async def test_returns_none_if_root_dir_does_not_exist(self):
        """Expects `get_versions` to return None when `root_dir` is missing."""
        root_dir = self.create_tempdir()
        program_state_mngr = file_program_state_manager.FileProgramStateManager(
            root_dir=root_dir, prefix='a_')
        # Delete the directory only AFTER constructing the manager. The
        # constructor appears to recreate a missing root_dir (another test in
        # this suite, test_creates_new_dir_with_root_dir_path_like, asserts
        # that), so deleting it beforehand would leave an existing, empty
        # directory under test instead of a missing one.
        shutil.rmtree(root_dir)
        self.assertFalse(os.path.exists(root_dir))

        versions = await program_state_mngr.get_versions()

        self.assertIsNone(versions)
예제 #13
0
    async def test_removes_saved_program_state_with_keep_first_false(self):
        """Checks pruning with `keep_total=3` and `keep_first=False`.

        Starting from directories `a_0` .. `a_9`, only `a_7`, `a_8` and `a_9`
        are expected to remain after `_remove_old_program_state`.
        """
        tmp_dir = self.create_tempdir()
        for number in range(10):
            state_dir = os.path.join(tmp_dir, f'a_{number}')
            os.mkdir(state_dir)
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir,
            prefix='a_',
            keep_total=3,
            keep_first=False,
        )

        await manager._remove_old_program_state()

        remaining = os.listdir(tmp_dir)
        self.assertCountEqual(remaining, ['a_7', 'a_8', 'a_9'])
예제 #14
0
    async def test_returns_none_with_no_saved_program_state(self):
        """Expects `get_versions` to return None when only temp files exist.

        The root directory is filled with ten files created by
        `tempfile.mkstemp` whose names start with the manager's prefix `a_`.
        """
        root_dir = self.create_tempdir()
        for _ in range(10):
            # mkstemp returns an already-open OS-level file descriptor along
            # with the path; close it so the test does not leak descriptors.
            fd, _unused_path = tempfile.mkstemp(
                prefix=os.path.join(root_dir, 'a_'))
            os.close(fd)
        program_state_mngr = file_program_state_manager.FileProgramStateManager(
            root_dir=root_dir, prefix='a_')

        versions = await program_state_mngr.get_versions()

        self.assertIsNone(versions)
예제 #15
0
    async def test_raises_structure_error(self):
        """Expects `load` to raise the structure error for a mismatched structure.

        A string state is saved as version 1 and then loaded with an empty
        list as the structure.
        """
        structure_error = (
            program_state_manager.ProgramStateManagerStructureError)
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')
        await manager.save('state_1', 1)
        mismatched_structure = []

        with self.assertRaises(structure_error):
            await manager.load(1, mismatched_structure)
예제 #16
0
    async def test_raises_version_not_found_error_with_unknown_version(self):
        """Expects `load(10, ...)` to raise state-not-found.

        Only version 1 is saved before the load is attempted.
        """
        not_found_error = (
            program_state_manager.ProgramStateManagerStateNotFoundError)
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')
        await manager.save('state_1', 1)

        with self.assertRaises(not_found_error):
            await manager.load(10, 'state')
예제 #17
0
    async def test_returns_versions_with_saved_program_state(self, count):
        """Expects `get_versions` to return `[0, ..., count - 1]`.

        One directory `a_<n>` is created for each `n` in `range(count)`.
        """
        tmp_dir = self.create_tempdir()
        for number in range(count):
            state_dir = os.path.join(tmp_dir, f'a_{number}')
            os.mkdir(state_dir)
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir,
            prefix='a_',
            keep_total=0,
        )

        found_versions = await manager.get_versions()

        self.assertEqual(found_versions, list(range(count)))
예제 #18
0
    async def test_does_not_remove_saved_program_state_with_keep_total_0(self):
        """Expects `_remove_old_program_state` to keep all dirs if `keep_total=0`."""
        tmp_dir = self.create_tempdir()
        # The same names are used to create the directories and as the
        # expected directory listing afterwards.
        dir_names = [f'a_{number}' for number in range(10)]
        for dir_name in dir_names:
            os.mkdir(os.path.join(tmp_dir, dir_name))
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir,
            prefix='a_',
            keep_total=0,
        )

        await manager._remove_old_program_state()

        self.assertCountEqual(os.listdir(tmp_dir), dir_names)
예제 #19
0
    async def test_removes_saved_program_state(self):
        """Expects `save` to call `_remove_old_program_state` exactly once."""
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')
        # Patch the cleanup method on this instance so calls are recorded.
        cleanup_patch = mock.patch.object(manager, '_remove_old_program_state')

        with cleanup_patch as mock_cleanup:
            await manager.save('state_1', 1)

            mock_cleanup.assert_called_once()
예제 #20
0
    async def test_raises_version_already_exists_error_with_existing_version(
            self):
        """Expects a second `save` of version 1 to raise already-exists."""
        already_exists_error = (
            program_state_manager.ProgramStateManagerStateAlreadyExistsError)
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')

        # The first save of version 1 runs outside the assertion context.
        await manager.save('state_1', 1)

        with self.assertRaises(already_exists_error):
            await manager.save('state_1', 1)
예제 #21
0
    async def test_removes_saved_program_state_with_version(self, version):
        """Expects `_remove(version)` to delete only that version's directory.

        Directories `a_0`, `a_1` and `a_2` are created first; after the call,
        every directory except `a_<version>` must still exist.

        Args:
            version: The version to remove, supplied by the test's
                parameterization (presumably one of 0, 1 or 2; the decorator
                is not visible here, so verify).
        """
        root_dir = self.create_tempdir()
        # Use a distinct loop variable: reusing the name `version` here would
        # overwrite the parameter, so every parameterization would end up
        # removing version 2.
        for existing_version in range(3):
            os.mkdir(os.path.join(root_dir, f'a_{existing_version}'))
        program_state_mngr = file_program_state_manager.FileProgramStateManager(
            root_dir=root_dir, prefix='a_')

        await program_state_mngr._remove(version)

        expected_dirs = ['a_0', 'a_1', 'a_2']
        expected_dirs.remove(f'a_{version}')
        self.assertCountEqual(os.listdir(root_dir), expected_dirs)
예제 #22
0
    async def test_returns_saved_program_state_with_version(self, version):
        """Expects `load(version, ...)` to return the state saved for it.

        States `'state_0'`, `'state_1'` and `'state_2'` are saved under
        versions 0, 1 and 2 before loading.
        """
        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')
        for number in range(3):
            await manager.save(f'state_{number}', number)

        loaded_state = await manager.load(version, 'state')

        self.assertEqual(loaded_state, f'state_{version}')
예제 #23
0
    async def test_returns_saved_program_state(self, program_state,
                                               expected_program_state):
        """Expects a save/load round trip to yield `expected_program_state`.

        `program_state` is saved as version 1 and is also passed to `load` as
        the structure argument.
        """

        def _as_comparable(value):
            # A top-level dataset is materialized to a list for comparison.
            return list(value) if isinstance(value, tf.data.Dataset) else value

        tmp_dir = self.create_tempdir()
        manager = file_program_state_manager.FileProgramStateManager(
            root_dir=tmp_dir, prefix='a_')
        await manager.save(program_state, 1)

        loaded_state = await manager.load(1, program_state)

        self.assertAllEqual(
            _as_comparable(loaded_state),
            _as_comparable(expected_program_state),
        )
예제 #24
0
    def test_raises_type_error_with_keep_total(self, keep_total):
        """Expects the constructor to raise TypeError for `keep_total`."""
        tmp_dir = self.create_tempdir()

        # Callable form of assertRaises: the constructor is invoked by the
        # assertion with the keyword arguments given here.
        self.assertRaises(
            TypeError,
            file_program_state_manager.FileProgramStateManager,
            root_dir=tmp_dir,
            keep_total=keep_total,
        )
예제 #25
0
    def test_raises_type_error_with_prefix(self, prefix):
        """Expects the constructor to raise TypeError for `prefix`."""
        tmp_dir = self.create_tempdir()

        # Callable form of assertRaises: the constructor is invoked by the
        # assertion with the keyword arguments given here.
        self.assertRaises(
            TypeError,
            file_program_state_manager.FileProgramStateManager,
            root_dir=tmp_dir,
            prefix=prefix,
        )
예제 #26
0
 def test_raises_value_error_with_root_dir_empty(self):
     """Expects the constructor to raise ValueError for `root_dir=''`."""
     self.assertRaises(
         ValueError,
         file_program_state_manager.FileProgramStateManager,
         root_dir='',
     )
예제 #27
0
 def test_raises_type_error_with_root_dir(self, root_dir):
     """Expects the constructor to raise TypeError for `root_dir`."""
     self.assertRaises(
         TypeError,
         file_program_state_manager.FileProgramStateManager,
         root_dir=root_dir,
     )