def testBestCheckpointsWithNan(self):
    """
    Checks how the manager ranks checkpoints when some mocked results
    are built from a NaN value.

    Two checkpoints with value ``None`` and a NaN-based result are mixed
    with one checkpoint of value ``3`` and fed to the manager in random
    order; the kept list must then be ``[None, 3]``.
    """
    num_to_keep = 2
    manager = self.checkpoint_manager(num_to_keep)

    candidates = []
    for step in range(2):
        nan_result = self.mock_result(float("nan"), step)
        candidates.append(
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, None, nan_result)
        )
    candidates.append(
        _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3))
    )

    # Insertion order must not influence the outcome.
    random.shuffle(candidates)
    for candidate in candidates:
        manager.on_checkpoint(candidate)

    kept = manager.best_checkpoints()
    # The kept list is ordered from worst to best.
    self.assertEqual(len(kept), num_to_keep)
    self.assertEqual(kept[0].value, None)
    self.assertEqual(kept[1].value, 3)
def testSameCheckpoint(self):
    """
    Registers four checkpoints, two of which point at the same file, and
    checks that the file behind each checkpoint still exists right after
    it has been handed to the manager.

    The manager is built with ``os.remove`` as its delete function, so the
    existence assertion fails if the manager removes the file of the
    checkpoint it was just given.
    """
    checkpoint_manager = CheckpointManager(
        1, "i", delete_fn=lambda c: os.remove(c.value)
    )

    tmpfiles = []
    try:
        for _ in range(3):
            fd, tmpfile = tempfile.mkstemp()
            # Track the path first so the ``finally`` clause cleans it up
            # even if something below raises.
            tmpfiles.append(tmpfile)
            # mkstemp() already created an empty file; the caller owns the
            # returned descriptor and must close it to avoid leaking it.
            os.close(fd)

        checkpoints = [
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5)
            ),
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10)
            ),
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0)
            ),
            # Same path as the second checkpoint, different result.
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20)
            ),
        ]
        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)
            self.assertTrue(os.path.exists(checkpoint.value))
    finally:
        # Remove whatever the manager left behind, also on test failure.
        for tmpfile in tmpfiles:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
def testNewestCheckpoint(self):
    """
    Registers a MEMORY checkpoint followed by a PERSISTENT one and checks
    that the manager reports the latter as its newest persistent
    checkpoint.
    """
    manager = self.checkpoint_manager(keep_checkpoints_num=1)

    in_memory = _TuneCheckpoint(
        _TuneCheckpoint.MEMORY, {0}, self.mock_result(0, 0)
    )
    manager.on_checkpoint(in_memory)

    on_disk = _TuneCheckpoint(
        _TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1, 1)
    )
    manager.on_checkpoint(on_disk)

    self.assertEqual(manager.newest_persistent_checkpoint, on_disk)
def testOnMemoryCheckpoint(self):
    """
    Registers two MEMORY checkpoints and checks that the second one is
    reported as the newest memory checkpoint while the list of best
    checkpoints stays empty.
    """
    first, second = (
        _TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0))
        for _ in range(2)
    )
    manager = self.checkpoint_manager(keep_checkpoints_num=1)

    for memory_checkpoint in (first, second):
        manager.on_checkpoint(memory_checkpoint)

    self.assertEqual(manager.newest_memory_checkpoint, second)
    self.assertEqual(manager.best_checkpoints(), [])
def testOnCheckpointUnordered(self):
    """
    Feeds checkpoints to the manager in descending order (3, 2, 1, 0)
    rather than ascending order, and checks the number of delete calls
    after every insertion as well as the final set of kept checkpoints.
    """
    num_to_keep = 2
    manager = self.checkpoint_manager(num_to_keep)

    descending = []
    for score in range(3, -1, -1):
        descending.append(
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, {score}, self.mock_result(score)
            )
        )

    with patch.object(manager, "delete") as delete_mock:
        for position, current in enumerate(descending):
            manager.on_checkpoint(current)
            # A single delete call is expected, and only once the fourth
            # (last) checkpoint has been registered.
            expected_deletes = 1 if position == 3 else 0
            self.assertEqual(delete_mock.call_count, expected_deletes)
            self.assertEqual(manager.newest_persistent_checkpoint, current)

    kept = manager.best_checkpoints()
    self.assertEqual(len(kept), num_to_keep)
    # The two checkpoints that were inserted first must be the ones kept.
    for expected in descending[:2]:
        self.assertIn(expected, kept)
def checkpoint(self):
    """Returns the most recent checkpoint.

    When the trial status is ``Trial.ERROR`` the newest PERSISTENT
    checkpoint of the checkpoint manager is used, otherwise its newest
    checkpoint. If the selected checkpoint has no value, a PERSISTENT
    checkpoint pointing at ``self.restore_path`` is returned instead.
    """
    newest = (
        self.checkpoint_manager.newest_persistent_checkpoint
        if self.status == Trial.ERROR
        else self.checkpoint_manager.newest_checkpoint
    )
    if newest.value is not None:
        return newest
    # Nothing has been stored yet: fall back to the restore path.
    return _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, self.restore_path)
def write_checkpoint(trial: Trial, index: int):
    """Create a checkpoint directory for ``trial`` and register it.

    A ``cp.json`` file holding ``{"training_iteration": index}`` is written
    into a directory obtained from ``TrainableUtil.make_checkpoint_dir``.
    The directory is wrapped in a PERSISTENT checkpoint, which is assigned
    to ``trial.saving_to`` and passed to ``trial.on_checkpoint``.

    Returns:
        The checkpoint directory returned by
        ``TrainableUtil.make_checkpoint_dir``.
    """
    target_dir = TrainableUtil.make_checkpoint_dir(trial.logdir, index=index)
    payload = {"training_iteration": index}

    payload_path = os.path.join(target_dir, "cp.json")
    with open(payload_path, "w") as handle:
        json.dump(payload, handle)

    persistent = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, target_dir, payload)
    trial.saving_to = persistent
    trial.on_checkpoint(persistent)
    return target_dir
def testOnCheckpointUnavailableAttribute(self):
    """
    Registers a checkpoint whose result dict is empty and checks that
    exactly one error is logged while the checkpoint is still recorded as
    the newest persistent one.
    """
    manager = self.checkpoint_manager(keep_checkpoints_num=1)
    scoreless = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 0, {})

    with patch.object(logger, "error") as log_error_mock:
        manager.on_checkpoint(scoreless)
        log_error_mock.assert_called_once()
        # Despite the logged error the checkpoint must have been tracked.
        self.assertEqual(manager.newest_persistent_checkpoint, scoreless)
def testBestCheckpoints(self):
    """
    Registers sixteen checkpoints (values 0..15) in random order and
    checks that the four kept ones are 12, 13, 14 and 15, in that order.
    """
    num_to_keep = 4
    manager = self.checkpoint_manager(num_to_keep)

    pool = []
    for score in range(16):
        pool.append(
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, score, self.mock_result(score)
            )
        )

    random.shuffle(pool)
    for candidate in pool:
        manager.on_checkpoint(candidate)

    kept = manager.best_checkpoints()
    self.assertEqual(len(kept), num_to_keep)
    # Position 0 must hold value 12, position 1 value 13, and so on.
    for rank, survivor in enumerate(kept):
        self.assertEqual(survivor.value, rank + 12)
def testBestCheckpointsOnlyNan(self):
    """
    Registers four checkpoints (values 0..3) whose mocked results are all
    built from NaN and checks that the two kept ones are the last two
    registered.
    """
    num_to_keep = 2
    manager = self.checkpoint_manager(num_to_keep)

    nan_only = []
    for step in range(4):
        nan_result = self.mock_result(float("nan"), step)
        nan_only.append(
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, step, nan_result)
        )

    for candidate in nan_only:
        manager.on_checkpoint(candidate)

    kept = manager.best_checkpoints()
    # The kept list is ordered from worst to best.
    self.assertEqual(len(kept), num_to_keep)
    self.assertEqual(kept[0].value, 2)
    self.assertEqual(kept[1].value, 3)
def testBestCheckpoints(self):
    """
    Checks, for every possible insertion order of eight checkpoints
    (values 0..7), that a fresh manager keeps values 4, 5, 6 and 7, in
    that order.
    """
    num_to_keep = 4

    pool = []
    for score in range(8):
        pool.append(
            _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, score, self.mock_result(score, score)
            )
        )

    for ordering in itertools.permutations(pool):
        # A new manager per ordering keeps the runs independent.
        manager = self.checkpoint_manager(num_to_keep)
        for candidate in ordering:
            manager.on_checkpoint(candidate)

        kept = manager.best_checkpoints()
        self.assertEqual(len(kept), num_to_keep)
        for rank, survivor in enumerate(kept):
            self.assertEqual(survivor.value, rank + 4)
def testCallbackSteps(self):
    """
    Walks two trials through a scripted sequence of executor events and
    checks what the callback recorded after each runner step.

    Each stage sets ``self.executor.next_future_result`` to the event to be
    consumed, calls ``self.trial_runner.step()`` and then inspects
    ``self.callback.state``, which is indexed by hook name (for example
    ``"trial_start"``) and holds the ``iteration`` and ``trial`` seen by
    that hook. The stages depend on each other, so their order matters.
    """
    trials = [
        Trial("__fake", trial_id="one"),
        Trial("__fake", trial_id="two")
    ]
    for t in trials:
        self.trial_runner.add_trial(t)

    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.PG_READY)
    self.trial_runner.step()

    # Trial 1 has been started
    self.assertEqual(self.callback.state["trial_start"]["iteration"], 0)
    self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id, "one")

    # None of these hooks has recorded anything yet
    self.assertTrue(
        all(k not in self.callback.state for k in [
            "trial_restore",
            "trial_save",
            "trial_result",
            "trial_complete",
            "trial_fail",
            "experiment_end",
        ]))

    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.PG_READY)
    self.trial_runner.step()

    # step_begin was recorded with iteration 1 (not increased yet)
    self.assertEqual(self.callback.state["step_begin"]["iteration"], 1)

    # step_end was recorded with iteration 2 (increased)
    self.assertEqual(self.callback.state["step_end"]["iteration"], 2)

    # Second trial has just been started
    self.assertEqual(self.callback.state["trial_start"]["iteration"], 1)
    self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id, "two")

    # Just a placeholder object ref for cp.value.
    cp = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, value=ray.put(1), result={TRAINING_ITERATION: 0})
    trials[0].saving_to = cp

    # Let the first trial save a checkpoint
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.SAVING_RESULT,
        trial=trials[0],
        result={_ExecutorEvent.KEY_FUTURE_RESULT: "__checkpoint"},
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
    self.assertEqual(self.callback.state["trial_save"]["trial"].trial_id, "one")

    # Let the second trial send a result
    result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
    # NOTE(review): this event uses the literal "future_result" key while the
    # other events use _ExecutorEvent.KEY_FUTURE_RESULT - presumably the same
    # key; confirm and consider using the constant here as well.
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.TRAINING_RESULT,
        trial=trials[1],
        result={"future_result": result},
    )
    # The second trial must not have reported anything before this step.
    self.assertTrue(not trials[1].has_reported_at_least_once)
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
    self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id, "two")
    self.assertEqual(
        self.callback.state["trial_result"]["result"]["metric"], 800)
    self.assertEqual(trials[1].last_result["metric"], 800)

    # Let the second trial restore from a checkpoint
    trials[1].restoring_from = cp
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.RESTORING_RESULT, trial=trials[1])
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
    self.assertEqual(
        self.callback.state["trial_restore"]["trial"].trial_id, "two")

    # Let the second trial finish (its result carries "done": True)
    trials[1].restoring_from = None
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.TRAINING_RESULT,
        trial=trials[1],
        result={
            _ExecutorEvent.KEY_FUTURE_RESULT: {
                TRAINING_ITERATION: 2,
                "metric": 900,
                "done": True,
            }
        },
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
    self.assertEqual(
        self.callback.state["trial_complete"]["trial"].trial_id, "two")

    # Let the first trial error
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.ERROR,
        trial=trials[0],
        result={_ExecutorEvent.KEY_EXCEPTION: Exception()},
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
    self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id, "one")