def testSameCheckpoint(self):
    checkpoint_manager = CheckpointManager(
        1, "i", delete_fn=lambda c: os.remove(c.value))

    tmpfiles = []
    for i in range(3):
        _, tmpfile = tempfile.mkstemp()
        with open(tmpfile, "wt") as fp:
            fp.write("")
        tmpfiles.append(tmpfile)

    checkpoints = [
        Checkpoint(Checkpoint.PERSISTENT, tmpfiles[0], self.mock_result(5)),
        Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1], self.mock_result(10)),
        Checkpoint(Checkpoint.PERSISTENT, tmpfiles[2], self.mock_result(0)),
        Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1], self.mock_result(20))
    ]
    for checkpoint in checkpoints:
        checkpoint_manager.on_checkpoint(checkpoint)
        self.assertTrue(os.path.exists(checkpoint.value))

    for tmpfile in tmpfiles:
        if os.path.exists(tmpfile):
            os.remove(tmpfile)
def testNewestCheckpoint(self):
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
    memory_checkpoint = Checkpoint(Checkpoint.MEMORY, {0},
                                   self.mock_result(0))
    checkpoint_manager.on_checkpoint(memory_checkpoint)
    persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT, {1},
                                       self.mock_result(1))
    checkpoint_manager.on_checkpoint(persistent_checkpoint)
    self.assertEqual(checkpoint_manager.newest_persistent_checkpoint,
                     persistent_checkpoint)
def testOnMemoryCheckpoint(self):
    checkpoints = [
        Checkpoint(Checkpoint.MEMORY, 0, self.mock_result(0)),
        Checkpoint(Checkpoint.MEMORY, 0, self.mock_result(0))
    ]
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
    checkpoint_manager.on_checkpoint(checkpoints[0])
    checkpoint_manager.on_checkpoint(checkpoints[1])
    newest = checkpoint_manager.newest_memory_checkpoint

    self.assertEqual(newest, checkpoints[1])
    self.assertEqual(checkpoint_manager.best_checkpoints(), [])
def testOnCheckpointUnordered(self):
    """
    Tests priorities that aren't inserted in ascending order. Also tests
    that the worst checkpoints are deleted when necessary.
    """
    keep_checkpoints_num = 2
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
    checkpoints = [
        Checkpoint(Checkpoint.PERSISTENT, {i}, self.mock_result(i))
        for i in range(3, -1, -1)
    ]

    with patch.object(checkpoint_manager, "delete") as delete_mock:
        for j in range(0, len(checkpoints)):
            checkpoint_manager.on_checkpoint(checkpoints[j])
            expected_deletes = 0 if j != 3 else 1
            self.assertEqual(delete_mock.call_count, expected_deletes)
            self.assertEqual(
                checkpoint_manager.newest_persistent_checkpoint,
                checkpoints[j])

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[0], best_checkpoints)
        self.assertIn(checkpoints[1], best_checkpoints)
def testOnCheckpointOrdered(self):
    """
    Tests increasing priorities. Also tests that the worst checkpoints
    are deleted when necessary.
    """
    keep_checkpoints_num = 2
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
    checkpoints = [
        Checkpoint(Checkpoint.PERSISTENT, {i}, self.mock_result(i))
        for i in range(3)
    ]

    with patch.object(checkpoint_manager, "delete") as delete_mock:
        for j in range(3):
            checkpoint_manager.on_checkpoint(checkpoints[j])
            expected_deletes = 0 if j != 2 else 1
            self.assertEqual(delete_mock.call_count, expected_deletes, j)
            self.assertEqual(checkpoint_manager.newest_checkpoint,
                             checkpoints[j])

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[1], best_checkpoints)
        self.assertIn(checkpoints[2], best_checkpoints)
def checkpoint(self):
    """Returns the most recent checkpoint.

    If the trial is in ERROR state, the most recent PERSISTENT checkpoint
    is returned.
    """
    if self.status == Trial.ERROR:
        checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
    else:
        checkpoint = self.checkpoint_manager.newest_checkpoint
    if checkpoint.value is None:
        checkpoint = Checkpoint(Checkpoint.PERSISTENT, self.restore_path)
    return checkpoint
def checkpoint(self):
    """Returns the most recent checkpoint.

    If the trial is PAUSED, this is the most recent MEMORY checkpoint.
    Otherwise, it is the most recent PERSISTENT checkpoint.
    """
    if self.status == Trial.PAUSED:
        assert self.checkpoint_manager.newest_memory_checkpoint.value
        return self.checkpoint_manager.newest_memory_checkpoint
    checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
    if checkpoint.value is None:
        checkpoint = Checkpoint(Checkpoint.PERSISTENT, self.restore_path)
    return checkpoint
def testOnCheckpointUnavailableAttribute(self):
    """
    Tests that an error is logged when the associated result of the
    checkpoint has no checkpoint score attribute.
    """
    keep_checkpoints_num = 1
    checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")

    no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {})
    with patch.object(logger, "error") as log_error_mock:
        checkpoint_manager.on_checkpoint(no_attr_checkpoint)
        log_error_mock.assert_called_once()
        # The newest checkpoint should still be set despite this error.
        assert checkpoint_manager.newest_checkpoint == no_attr_checkpoint
def testOnCheckpointUnavailableAttribute(self):
    """
    Tests that an error is logged when the associated result of the
    checkpoint has no checkpoint score attribute.
    """
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)

    no_attr_checkpoint = Checkpoint(Checkpoint.PERSISTENT, 0, {})
    with patch.object(logger, "error") as log_error_mock:
        checkpoint_manager.on_checkpoint(no_attr_checkpoint)
        log_error_mock.assert_called_once()
        # The newest checkpoint should still be set despite this error.
        self.assertEqual(checkpoint_manager.newest_persistent_checkpoint,
                         no_attr_checkpoint)
def testBestCheckpoints(self):
    """
    Tests that the best checkpoints are tracked and ordered correctly.
    """
    keep_checkpoints_num = 4
    checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
    checkpoints = [
        Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i))
        for i in range(16)
    ]
    random.shuffle(checkpoints)

    for checkpoint in checkpoints:
        checkpoint_manager.on_checkpoint(checkpoint)

    best_checkpoints = checkpoint_manager.best_checkpoints()
    self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
    for i in range(len(best_checkpoints)):
        self.assertEqual(best_checkpoints[i].value, i + 12)
def load_one_tune_analysis(
        checkpoints_paths: list,
        result: dict = {
            "training_iteration": 1,
            "episode_reward_mean": 1
        },
        default_metric: str = "episode_reward_mean",
        default_mode: str = "max",
        n_dir_level_between_ckpt_and_exp_state=1,
):
    """Helper to re-create a fake tune_analysis only containing the
    checkpoints provided."""
    assert default_metric in result.keys()

    register_trainable("fake trial", Trainable)
    trials = []
    for one_checkpoint_path in checkpoints_paths:
        one_trial = Trial(trainable_name="fake trial")
        ckpt = Checkpoint(
            Checkpoint.PERSISTENT,
            value=one_checkpoint_path,
            result=result)
        one_trial.checkpoint_manager.on_checkpoint(ckpt)
        trials.append(one_trial)

    json_file_path = _get_experiment_state_file_path(
        checkpoints_paths[0],
        split_path_n_times=n_dir_level_between_ckpt_and_exp_state,
    )
    one_tune_analysis = ExperimentAnalysis(
        experiment_checkpoint_path=json_file_path,
        trials=trials,
        default_mode=default_mode,
        default_metric=default_metric,
    )

    for trial in one_tune_analysis.trials:
        assert len(trial.checkpoint_manager.best_checkpoints()) == 1

    return one_tune_analysis
def testOnCheckpointUnordered(self):
    """
    Tests priorities that aren't inserted in ascending order. Also tests
    that the worst checkpoints are deleted when necessary.
    """
    keep_checkpoints_num = 2
    checkpoint_manager = CheckpointManager(keep_checkpoints_num, "i")
    checkpoints = [
        Checkpoint(Checkpoint.DISK, {i}, self.mock_result(i))
        for i in range(3, -1, -1)
    ]

    with patch("shutil.rmtree") as rmtree_mock, patch("os.path"):
        for j in range(0, len(checkpoints)):
            checkpoint_manager.on_checkpoint(checkpoints[j])
            expected_deletes = 0 if j != 3 else 1
            self.assertEqual(rmtree_mock.call_count, expected_deletes)
            self.assertEqual(checkpoint_manager.newest_checkpoint,
                             checkpoints[j])

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertIn(checkpoints[0], best_checkpoints)
        self.assertIn(checkpoints[1], best_checkpoints)
def testCallbackSteps(self):
    trials = [
        Trial("__fake", trial_id="one"),
        Trial("__fake", trial_id="two")
    ]
    for t in trials:
        self.trial_runner.add_trial(t)

    self.executor.next_trial = trials[0]
    self.trial_runner.step()

    # Trial 1 has been started
    self.assertEqual(self.callback.state["trial_start"]["iteration"], 0)
    self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id,
                     "one")

    # All these events haven't happened, yet
    self.assertTrue(
        all(k not in self.callback.state for k in [
            "trial_restore", "trial_save", "trial_result",
            "trial_complete", "trial_fail"
        ]))

    self.executor.next_trial = trials[1]
    self.trial_runner.step()

    # Iteration not increased yet
    self.assertEqual(self.callback.state["step_begin"]["iteration"], 1)

    # Iteration increased
    self.assertEqual(self.callback.state["step_end"]["iteration"], 2)

    # Second trial has been just started
    self.assertEqual(self.callback.state["trial_start"]["iteration"], 1)
    self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id,
                     "two")

    cp = Checkpoint(Checkpoint.PERSISTENT, "__checkpoint",
                    {TRAINING_ITERATION: 0})

    # Let the first trial save a checkpoint
    self.executor.next_trial = trials[0]
    trials[0].saving_to = cp
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
    self.assertEqual(self.callback.state["trial_save"]["trial"].trial_id,
                     "one")

    # Let the second trial send a result
    result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
    self.executor.results[trials[1]] = result
    self.executor.next_trial = trials[1]
    self.assertEqual(trials[1].last_result, {})
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
    self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id,
                     "two")
    self.assertEqual(
        self.callback.state["trial_result"]["result"]["metric"], 800)
    self.assertEqual(trials[1].last_result["metric"], 800)

    # Let the second trial restore from a checkpoint
    trials[1].restoring_from = cp
    self.executor.results[trials[1]] = trials[1].last_result
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
    self.assertEqual(
        self.callback.state["trial_restore"]["trial"].trial_id, "two")

    # Let the second trial finish
    trials[1].restoring_from = None
    self.executor.results[trials[1]] = {
        TRAINING_ITERATION: 2,
        "metric": 900,
        "done": True
    }
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
    self.assertEqual(
        self.callback.state["trial_complete"]["trial"].trial_id, "two")

    # Let the first trial error
    self.executor.failed_trial = trials[0]
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
    self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id,
                     "one")
def __init__(self,
             trainable_name,
             config=None,
             trial_id=None,
             local_dir=DEFAULT_RESULTS_DIR,
             evaluated_params=None,
             experiment_tag="",
             resources=None,
             stopping_criterion=None,
             remote_checkpoint_dir=None,
             checkpoint_freq=0,
             checkpoint_at_end=False,
             sync_on_checkpoint=True,
             keep_checkpoints_num=None,
             checkpoint_score_attr=TRAINING_ITERATION,
             export_formats=None,
             restore_path=None,
             trial_name_creator=None,
             loggers=None,
             sync_to_driver_fn=None,
             max_failures=0):
    """Initialize a new trial.

    The args here take the same meaning as the command line flags defined
    in ray.tune.config_parser.
    """
    validate_trainable(trainable_name)
    # Trial config
    self.trainable_name = trainable_name
    self.trial_id = Trial.generate_id() if trial_id is None else trial_id
    self.config = config or {}
    self.local_dir = local_dir  # This remains unexpanded for syncing.

    #: Parameters that Tune varies across searches.
    self.evaluated_params = evaluated_params or {}
    self.experiment_tag = experiment_tag
    trainable_cls = self.get_trainable_cls()
    if trainable_cls and hasattr(trainable_cls,
                                 "default_resource_request"):
        default_resources = trainable_cls.default_resource_request(
            self.config)
        if default_resources:
            if resources:
                raise ValueError(
                    "Resources for {} have been automatically set to {} "
                    "by its `default_resource_request()` method. Please "
                    "clear the `resources_per_trial` option.".format(
                        trainable_cls, default_resources))
            resources = default_resources
    self.location = Location()
    self.resources = resources or Resources(cpu=1, gpu=0)
    self.stopping_criterion = stopping_criterion or {}
    self.loggers = loggers
    self.sync_to_driver_fn = sync_to_driver_fn
    self.verbose = True
    self.max_failures = max_failures

    # Local trial state that is updated during the run
    self.last_result = {}
    self.last_update_time = -float("inf")

    # stores in memory max/min/last result for each metric by trial
    self.metric_analysis = {}

    self.export_formats = export_formats
    self.status = Trial.PENDING
    self.start_time = None
    self.logdir = None
    self.runner = None
    self.result_logger = None
    self.last_debug = 0
    self.error_file = None
    self.error_msg = None
    self.custom_trial_name = None

    # Checkpointing fields
    if remote_checkpoint_dir:
        self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
    else:
        self.remote_checkpoint_dir_prefix = None
    self.checkpoint_freq = checkpoint_freq
    self.checkpoint_at_end = checkpoint_at_end
    self.sync_on_checkpoint = sync_on_checkpoint
    newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
    self.checkpoint_manager = CheckpointManager(
        keep_checkpoints_num, checkpoint_score_attr,
        checkpoint_deleter(str(self), self.runner))
    self.checkpoint_manager.newest_checkpoint = newest_checkpoint

    # Restoration fields
    self.restoring_from = None
    self.num_failures = 0
    self.num_consecutive_start_attempts = 0

    # AutoML fields
    self.results = None
    self.best_result = None
    self.param_config = None
    self.extra_arg = None

    self._nonjson_fields = [
        "loggers",
        "sync_to_driver_fn",
        "results",
        "best_result",
        "param_config",
        "extra_arg",
    ]
    if trial_name_creator:
        self.custom_trial_name = trial_name_creator(self)
def testCallbackSteps(self):
    trials = [
        Trial("__fake", trial_id="one"),
        Trial("__fake", trial_id="two")
    ]
    for t in trials:
        self.trial_runner.add_trial(t)

    self.executor.next_future_result = ExecutorEvent(
        event_type=ExecutorEventType.PG_READY)
    self.trial_runner.step()

    # Trial 1 has been started
    self.assertEqual(self.callback.state["trial_start"]["iteration"], 0)
    self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id,
                     "one")

    # All these events haven't happened, yet
    self.assertTrue(
        all(k not in self.callback.state for k in [
            "trial_restore",
            "trial_save",
            "trial_result",
            "trial_complete",
            "trial_fail",
            "experiment_end",
        ]))

    self.executor.next_future_result = ExecutorEvent(
        event_type=ExecutorEventType.PG_READY)
    self.trial_runner.step()

    # Iteration not increased yet
    self.assertEqual(self.callback.state["step_begin"]["iteration"], 1)

    # Iteration increased
    self.assertEqual(self.callback.state["step_end"]["iteration"], 2)

    # Second trial has been just started
    self.assertEqual(self.callback.state["trial_start"]["iteration"], 1)
    self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id,
                     "two")

    cp = Checkpoint(Checkpoint.PERSISTENT, "__checkpoint",
                    {TRAINING_ITERATION: 0})

    # Let the first trial save a checkpoint
    self.executor.next_future_result = ExecutorEvent(
        event_type=ExecutorEventType.SAVING_RESULT, trial=trials[0])
    trials[0].saving_to = cp
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
    self.assertEqual(self.callback.state["trial_save"]["trial"].trial_id,
                     "one")

    # Let the second trial send a result
    result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
    self.executor.next_future_result = ExecutorEvent(
        event_type=ExecutorEventType.TRAINING_RESULT,
        trial=trials[1],
        result=result)
    self.assertTrue(not trials[1].has_reported_at_least_once)
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
    self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id,
                     "two")
    self.assertEqual(
        self.callback.state["trial_result"]["result"]["metric"], 800)
    self.assertEqual(trials[1].last_result["metric"], 800)

    # Let the second trial restore from a checkpoint
    trials[1].restoring_from = cp
    self.executor.next_future_result = ExecutorEvent(
        event_type=ExecutorEventType.RESTORING_RESULT, trial=trials[1])
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
    self.assertEqual(
        self.callback.state["trial_restore"]["trial"].trial_id, "two")

    # Let the second trial finish
    trials[1].restoring_from = None
    self.executor.next_future_result = ExecutorEvent(
        event_type=ExecutorEventType.TRAINING_RESULT,
        trial=trials[1],
        result={
            TRAINING_ITERATION: 2,
            "metric": 900,
            "done": True
        },
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
    self.assertEqual(
        self.callback.state["trial_complete"]["trial"].trial_id, "two")

    # Let the first trial error
    self.executor.next_future_result = ExecutorEvent(
        event_type=ExecutorEventType.ERROR,
        trial=trials[0],
        result=(Exception(), "error"),
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
    self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id,
                     "one")