示例#1
0
    def testBestCheckpointsWithNan(self):
        """
        Checks ranking when some checkpoints carry a NaN score.

        Two NaN-scored checkpoints (value ``None``) and one checkpoint scored
        0 (value ``3``) are registered in random order with two slots kept.
        The kept list must hold a NaN checkpoint first and the scored one last.
        """
        keep = 2
        checkpoint_manager = self.checkpoint_manager(keep)

        nan_scored = [
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT,
                None,
                self.mock_result(float("nan"), iteration),
            )
            for iteration in range(2)
        ]
        zero_scored = _TuneCheckpoint(
            _TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3)
        )
        pending = nan_scored + [zero_scored]
        # Randomize insertion order; the final ranking must not depend on it.
        random.shuffle(pending)

        for cp in pending:
            checkpoint_manager.on_checkpoint(cp)

        kept = checkpoint_manager.best_checkpoints()
        # The kept list is ordered from worst to best.
        self.assertEqual(len(kept), keep)
        self.assertEqual(kept[0].value, None)
        self.assertEqual(kept[1].value, 3)
示例#2
0
    def testSameCheckpoint(self):
        """
        Tests that re-registering a checkpoint whose value (a file path) is
        already in use does not delete that file.

        ``delete_fn`` removes the file backing a checkpoint. Therefore, after
        every ``on_checkpoint`` call, the file of the checkpoint that was just
        registered must still exist. The fourth checkpoint reuses the path of
        the second one.
        """
        checkpoint_manager = CheckpointManager(
            1, "i", delete_fn=lambda c: os.remove(c.value)
        )

        tmpfiles = []
        try:
            for _ in range(3):
                fd, tmpfile = tempfile.mkstemp()
                # mkstemp hands back an already-open OS-level descriptor; wrap
                # it so it gets closed instead of leaking one fd per file.
                with os.fdopen(fd, "wt") as fp:
                    fp.write("")
                tmpfiles.append(tmpfile)

            checkpoints = [
                _TuneCheckpoint(
                    _TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5)
                ),
                _TuneCheckpoint(
                    _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10)
                ),
                _TuneCheckpoint(
                    _TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0)
                ),
                _TuneCheckpoint(
                    _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20)
                ),
            ]
            for checkpoint in checkpoints:
                checkpoint_manager.on_checkpoint(checkpoint)
                self.assertTrue(os.path.exists(checkpoint.value))
        finally:
            # Remove the temp files even when an assertion above fails.
            for tmpfile in tmpfiles:
                if os.path.exists(tmpfile):
                    os.remove(tmpfile)
示例#3
0
 def testNewestCheckpoint(self):
     """Registers a MEMORY checkpoint and then a PERSISTENT one, and checks
     that the PERSISTENT one is reported as newest_persistent_checkpoint."""
     manager = self.checkpoint_manager(keep_checkpoints_num=1)

     memory_cp = _TuneCheckpoint(
         _TuneCheckpoint.MEMORY, {0}, self.mock_result(0, 0))
     manager.on_checkpoint(memory_cp)

     persistent_cp = _TuneCheckpoint(
         _TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1, 1))
     manager.on_checkpoint(persistent_cp)

     self.assertEqual(manager.newest_persistent_checkpoint, persistent_cp)
示例#4
0
    def testOnMemoryCheckpoint(self):
        """
        Registers two MEMORY checkpoints and checks two things: the second one
        is reported as newest_memory_checkpoint, and best_checkpoints() stays
        empty.
        """
        first_cp, second_cp = (
            _TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0))
            for _ in range(2)
        )
        manager = self.checkpoint_manager(keep_checkpoints_num=1)
        for cp in (first_cp, second_cp):
            manager.on_checkpoint(cp)

        self.assertEqual(manager.newest_memory_checkpoint, second_cp)
        self.assertEqual(manager.best_checkpoints(), [])
示例#5
0
    def testOnCheckpointUnordered(self):
        """
        Tests priorities that aren't inserted in ascending order. Also tests
        that the worst checkpoints are deleted when necessary.

        Checkpoints with priorities 3, 2, 1, 0 are registered in that order,
        with ``keep_checkpoints_num=2``. Each one must become the newest
        persistent checkpoint as soon as it is registered. The two kept at the
        end must be the ones with priorities 3 and 2.
        """
        keep_checkpoints_num = 2
        checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
        checkpoints = [
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i))
            for i in range(3, -1, -1)
        ]

        with patch.object(checkpoint_manager, "delete") as delete_mock:
            for j, checkpoint in enumerate(checkpoints):
                checkpoint_manager.on_checkpoint(checkpoint)
                # call_count is cumulative: no deletes after the first three
                # insertions, exactly one after the fourth.
                expected_deletes = 0 if j != 3 else 1
                self.assertEqual(delete_mock.call_count, expected_deletes)
                self.assertEqual(
                    checkpoint_manager.newest_persistent_checkpoint, checkpoint
                )

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[0], best_checkpoints)
        self.assertIn(checkpoints[1], best_checkpoints)
示例#6
0
    def checkpoint(self):
        """Returns the most recent checkpoint.

        If the trial is in ERROR state, the most recent PERSISTENT checkpoint
        is returned. If the selected checkpoint has no value (``value is
        None``), a new PERSISTENT checkpoint pointing at ``self.restore_path``
        is returned instead.
        """
        errored = self.status == Trial.ERROR
        manager = self.checkpoint_manager
        latest = (
            manager.newest_persistent_checkpoint
            if errored
            else manager.newest_checkpoint
        )
        if latest.value is not None:
            return latest
        # Nothing recorded yet; fall back to the configured restore path.
        return _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, self.restore_path)
示例#7
0
        def write_checkpoint(trial: Trial, index: int):
            """Write a ``cp.json`` checkpoint for ``trial`` at ``index``.

            The checkpoint is also registered with the trial via
            ``on_checkpoint``. Returns the checkpoint directory.
            """
            cp_dir = TrainableUtil.make_checkpoint_dir(trial.logdir, index=index)
            payload = {"training_iteration": index}
            cp_path = os.path.join(cp_dir, "cp.json")
            with open(cp_path, "w") as handle:
                json.dump(payload, handle)

            tune_checkpoint = _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, cp_dir, payload
            )
            # saving_to is set before on_checkpoint is called; presumably
            # on_checkpoint relies on it (confirm against Trial).
            trial.saving_to = tune_checkpoint
            trial.on_checkpoint(tune_checkpoint)
            return cp_dir
示例#8
0
    def testOnCheckpointUnavailableAttribute(self):
        """
        Verifies how a checkpoint is handled when its result has no
        checkpoint score attribute. Registering it must log exactly one
        error, and the checkpoint must still be recorded as the newest
        persistent checkpoint.
        """
        manager = self.checkpoint_manager(keep_checkpoints_num=1)
        unscored = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 0, {})

        with patch.object(logger, "error") as error_spy:
            manager.on_checkpoint(unscored)
            error_spy.assert_called_once()
            # Even with the error logged, the checkpoint must be tracked.
            self.assertEqual(manager.newest_persistent_checkpoint, unscored)
示例#9
0
    def testBestCheckpoints(self):
        """
        Tests that the best checkpoints are tracked and ordered correctly.

        Sixteen checkpoints (value == score == i) are registered in random
        order. The kept checkpoints must be those with the highest
        ``keep_checkpoints_num`` values, listed in ascending order.
        """
        keep_checkpoints_num = 4
        num_checkpoints = 16
        checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
        checkpoints = [
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i))
            for i in range(num_checkpoints)
        ]
        random.shuffle(checkpoints)

        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        # Lowest value expected among the kept checkpoints.
        first_kept = num_checkpoints - keep_checkpoints_num
        for rank, kept in enumerate(best_checkpoints):
            self.assertEqual(kept.value, first_kept + rank)
示例#10
0
    def testBestCheckpointsOnlyNan(self):
        """
        Verifies best_checkpoints() when every checkpoint has a NaN score.

        Four NaN-scored checkpoints (values 0..3) are registered in order,
        with two slots kept. The retained checkpoints must be the last two
        inserted (values 2 and 3), in that order.
        """
        keep = 2
        manager = self.checkpoint_manager(keep)
        nan_checkpoints = [
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT,
                step,
                self.mock_result(float("nan"), step),
            )
            for step in range(4)
        ]

        for cp in nan_checkpoints:
            manager.on_checkpoint(cp)

        kept = manager.best_checkpoints()
        # The kept list is ordered from worst to best.
        self.assertEqual(len(kept), keep)
        self.assertEqual(kept[0].value, 2)
        self.assertEqual(kept[1].value, 3)
示例#11
0
    def testBestCheckpoints(self):
        """
        Tests that the best checkpoints are tracked and ordered correctly.

        Every insertion order of eight checkpoints is replayed against a fresh
        manager, where each checkpoint is built from ``mock_result(i, i)``
        with value ``i``. For every order, the kept checkpoints must be those
        with the highest ``keep_checkpoints_num`` values, listed in ascending
        order.
        """
        keep_checkpoints_num = 4
        num_checkpoints = 8
        checkpoints = [
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i,
                            self.mock_result(i, i))
            for i in range(num_checkpoints)
        ]
        # Lowest value expected among the kept checkpoints.
        first_kept = num_checkpoints - keep_checkpoints_num

        # Exhaustive: 8! orderings, each with its own manager.
        for permutation in itertools.permutations(checkpoints):
            checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)

            for checkpoint in permutation:
                checkpoint_manager.on_checkpoint(checkpoint)

            best_checkpoints = checkpoint_manager.best_checkpoints()
            self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
            for rank, kept in enumerate(best_checkpoints):
                self.assertEqual(kept.value, first_kept + rank)
示例#12
0
    def testCallbackSteps(self):
        """
        Drives the trial runner one step at a time with scripted executor
        events. After each step, checks that the matching callback hook fired
        with the expected runner iteration and trial.

        Event sequence:
        - Two PG_READY events start trials "one" and "two".
        - Trial "one" saves a checkpoint.
        - Trial "two" reports a result, restores from that checkpoint, and
          then completes.
        - Trial "one" errors.
        """
        trials = [
            Trial("__fake", trial_id="one"),
            Trial("__fake", trial_id="two")
        ]
        for t in trials:
            self.trial_runner.add_trial(t)

        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.PG_READY)
        self.trial_runner.step()

        # Trial 1 has been started
        self.assertEqual(self.callback.state["trial_start"]["iteration"], 0)
        self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id,
                         "one")

        # All these events haven't happened, yet
        self.assertTrue(
            all(k not in self.callback.state for k in [
                "trial_restore",
                "trial_save",
                "trial_result",
                "trial_complete",
                "trial_fail",
                "experiment_end",
            ]))

        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.PG_READY)
        self.trial_runner.step()

        # Iteration not increased yet
        self.assertEqual(self.callback.state["step_begin"]["iteration"], 1)

        # Iteration increased
        self.assertEqual(self.callback.state["step_end"]["iteration"], 2)

        # Second trial has been just started
        self.assertEqual(self.callback.state["trial_start"]["iteration"], 1)
        self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id,
                         "two")

        # Just a placeholder object ref for cp.value.
        cp = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT,
                             value=ray.put(1),
                             result={TRAINING_ITERATION: 0})
        trials[0].saving_to = cp

        # Let the first trial save a checkpoint
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.SAVING_RESULT,
            trial=trials[0],
            result={_ExecutorEvent.KEY_FUTURE_RESULT: "__checkpoint"},
        )
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
        self.assertEqual(self.callback.state["trial_save"]["trial"].trial_id,
                         "one")

        # Let the second trial send a result
        result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.TRAINING_RESULT,
            trial=trials[1],
            # Use the shared key constant, like every other event in this test.
            result={_ExecutorEvent.KEY_FUTURE_RESULT: result},
        )
        self.assertFalse(trials[1].has_reported_at_least_once)
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
        self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id,
                         "two")
        self.assertEqual(
            self.callback.state["trial_result"]["result"]["metric"], 800)
        self.assertEqual(trials[1].last_result["metric"], 800)

        # Let the second trial restore from a checkpoint
        trials[1].restoring_from = cp
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.RESTORING_RESULT, trial=trials[1])
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
        self.assertEqual(
            self.callback.state["trial_restore"]["trial"].trial_id, "two")

        # Let the second trial finish
        trials[1].restoring_from = None
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.TRAINING_RESULT,
            trial=trials[1],
            result={
                _ExecutorEvent.KEY_FUTURE_RESULT: {
                    TRAINING_ITERATION: 2,
                    "metric": 900,
                    "done": True,
                }
            },
        )
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
        self.assertEqual(
            self.callback.state["trial_complete"]["trial"].trial_id, "two")

        # Let the first trial error
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.ERROR,
            trial=trials[0],
            result={_ExecutorEvent.KEY_EXCEPTION: Exception()},
        )
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
        self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id,
                         "one")