Пример #1
0
    def test_raises_file_not_found_error_with_no_checkpoint(self):
        """Loading round 0 before anything was saved raises FileNotFoundError."""
        manager = checkpoint_manager.FileCheckpointManager(
            self.get_temp_dir())
        template = _create_test_state()

        # The return value is irrelevant; only the raised error matters.
        with self.assertRaises(FileNotFoundError):
            manager.load_checkpoint(template, 0)
Пример #2
0
    def test_saves_one_checkpoint(self):
        """Saving round 1 leaves a single 'ckpt_1' entry in the directory."""
        root_dir = self.get_temp_dir()
        manager = checkpoint_manager.FileCheckpointManager(root_dir)

        manager.save_checkpoint(_create_test_state(1), 1)

        saved_entries = os.listdir(root_dir)
        self.assertCountEqual(saved_entries, ['ckpt_1'])
Пример #3
0
    def test_raises_value_error_with_bad_structure(self):
        """Loading the latest checkpoint with a None structure raises ValueError."""
        manager = checkpoint_manager.FileCheckpointManager(
            self.get_temp_dir())
        manager.save_checkpoint(_create_test_state(1), 1)

        # A checkpoint exists, so the failure is due to the None structure.
        with self.assertRaises(ValueError):
            manager.load_latest_checkpoint(None)
Пример #4
0
    def test_returns_none_with_no_checkpoints(self):
        """With nothing saved, the latest state and round number are both None."""
        manager = checkpoint_manager.FileCheckpointManager(
            self.get_temp_dir())
        template = _create_test_state()

        latest_state, latest_round = manager.load_latest_checkpoint(template)

        self.assertIsNone(latest_state)
        self.assertIsNone(latest_round)
Пример #5
0
    def test_raises_already_exists_error_with_existing_round_number(self):
        """Saving round 1 a second time raises tf.errors.AlreadyExistsError."""
        manager = checkpoint_manager.FileCheckpointManager(
            self.get_temp_dir())
        state = _create_test_state(1)

        # The first save succeeds; the duplicate round number must be rejected.
        manager.save_checkpoint(state, 1)
        with self.assertRaises(tf.errors.AlreadyExistsError):
            manager.save_checkpoint(state, 1)
Пример #6
0
    def test_returns_state_with_one_checkpoint(self):
        """A state saved for round 1 is returned when round 1 is loaded."""
        manager = checkpoint_manager.FileCheckpointManager(
            self.get_temp_dir())
        saved_state = _create_test_state(1)
        manager.save_checkpoint(saved_state, 1)

        template = _create_test_state()
        restored_state = manager.load_checkpoint(template, 1)

        self.assertEqual(restored_state, saved_state)
Пример #7
0
    def test_saves_three_checkpoints(self):
        """Saving rounds 1 through 3 leaves ckpt_1, ckpt_2 and ckpt_3 on disk."""
        root_dir = self.get_temp_dir()
        manager = checkpoint_manager.FileCheckpointManager(root_dir)

        for round_num in (1, 2, 3):
            manager.save_checkpoint(_create_test_state(round_num), round_num)

        saved_entries = os.listdir(root_dir)
        self.assertCountEqual(saved_entries, ['ckpt_1', 'ckpt_2', 'ckpt_3'])
Пример #8
0
    def test_save_with_nondefault_checkpoints_per_round(self):
        """With step=3 and keep_total=3, rounds 0-7 leave ckpt_0, ckpt_3, ckpt_6."""
        root_dir = self.get_temp_dir()
        manager = checkpoint_manager.FileCheckpointManager(
            root_dir, step=3, keep_total=3)

        # save_checkpoint is called for every round 0..7, not just multiples
        # of 3; the assertion checks which ones end up in the directory.
        for round_num in range(8):
            manager.save_checkpoint(_create_test_state(round_num), round_num)

        saved_entries = os.listdir(root_dir)
        self.assertCountEqual(saved_entries, ['ckpt_0', 'ckpt_3', 'ckpt_6'])
Пример #9
0
    def test_removes_oldest_with_keep_first_false(self):
        """With keep_total=3 and keep_first=False, ckpt_1 is gone after 4 saves."""
        root_dir = self.get_temp_dir()
        manager = checkpoint_manager.FileCheckpointManager(
            root_dir, keep_total=3, keep_first=False)

        for round_num in (1, 2, 3, 4):
            manager.save_checkpoint(_create_test_state(round_num), round_num)

        saved_entries = os.listdir(root_dir)
        self.assertCountEqual(saved_entries, ['ckpt_2', 'ckpt_3', 'ckpt_4'])
Пример #10
0
    def test_keep_all_checkpoints(self, keep_total):
        """Every one of three saved checkpoints remains for the given keep_total.

        Args:
            keep_total: Forwarded to FileCheckpointManager together with
                keep_first=False. Presumably supplied by a parameterized
                decorator outside this view -- TODO confirm.
        """
        root_dir = self.get_temp_dir()
        manager = checkpoint_manager.FileCheckpointManager(
            root_dir, keep_total=keep_total, keep_first=False)

        for round_num in (1, 2, 3):
            manager.save_checkpoint(_create_test_state(round_num), round_num)

        saved_entries = os.listdir(root_dir)
        self.assertCountEqual(saved_entries, ['ckpt_1', 'ckpt_2', 'ckpt_3'])
Пример #11
0
    def test_saves_three_checkpoints(self):
        """Saving rounds 1, 2 and 3 produces ckpt_1, ckpt_2 and ckpt_3."""
        root_dir = self.get_temp_dir()
        manager = checkpoint_manager.FileCheckpointManager(root_dir)

        # Create and save each state in round order.
        for round_num in range(1, 4):
            state = _create_test_state(round_num)
            manager.save_checkpoint(state, round_num)

        expected_entries = ['ckpt_1', 'ckpt_2', 'ckpt_3']
        self.assertCountEqual(os.listdir(root_dir), expected_entries)
Пример #12
0
    def test_returns_state_with_three_checkpoint_for_first_round(self):
        """Round 1 can still be loaded after rounds 2 and 3 have been saved."""
        manager = checkpoint_manager.FileCheckpointManager(
            self.get_temp_dir())

        # Keep each saved state by round so the expected one can be looked up.
        states_by_round = {}
        for round_num in range(1, 4):
            states_by_round[round_num] = _create_test_state(round_num)
            manager.save_checkpoint(states_by_round[round_num], round_num)

        template = _create_test_state()
        restored_state = manager.load_checkpoint(template, 1)

        self.assertEqual(restored_state, states_by_round[1])
Пример #13
0
    def test_returns_state_and_round_num_with_three_checkpoints(self):
        """The latest checkpoint after saving rounds 1-3 is round 3's state."""
        manager = checkpoint_manager.FileCheckpointManager(
            self.get_temp_dir())

        # Keep each saved state by round so the expected one can be looked up.
        states_by_round = {}
        for round_num in range(1, 4):
            states_by_round[round_num] = _create_test_state(round_num)
            manager.save_checkpoint(states_by_round[round_num], round_num)

        template = _create_test_state()
        latest_state, latest_round = manager.load_latest_checkpoint(template)

        self.assertEqual(latest_state, states_by_round[3])
        self.assertEqual(latest_round, 3)
Пример #14
0
    def test_removes_oldest_with_keep_first_false(self):
        """With keep_total=3 and keep_first=False, only rounds 2-4 remain."""
        root_dir = self.get_temp_dir()
        manager = checkpoint_manager.FileCheckpointManager(
            root_dir, keep_total=3, keep_first=False)

        # Four saves against a limit of three; the assertion below expects
        # the round-1 checkpoint to be the one that is missing.
        for round_num in range(1, 5):
            state = _create_test_state(round_num)
            manager.save_checkpoint(state, round_num)

        expected_entries = ['ckpt_2', 'ckpt_3', 'ckpt_4']
        self.assertCountEqual(os.listdir(root_dir), expected_entries)