def testBestCheckpointsWithNan(self):
    """
    Tests that checkpoints with nan priority are handled correctly.
    """
    keep_checkpoints_num = 2
    checkpoints = [
        _TrackedCheckpoint(
            dir_or_data=None,
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(float("nan"), i),
        )
        for i in range(2)
    ] + [
        _TrackedCheckpoint(
            dir_or_data=3,
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(0, 3),
        )
    ]

    for permutation in itertools.permutations(checkpoints):
        checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
        for checkpoint in permutation:
            checkpoint_manager.on_checkpoint(checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        # best_checkpoints is sorted from worst to best
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        self.assertEqual(best_checkpoints[0].dir_or_data, None)
        self.assertEqual(best_checkpoints[1].dir_or_data, 3)
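# The tests in this section depend on two helpers whose definitions are not
# shown here. The sketches below are assumptions for readability, not the
# original implementations: `mock_result` builds a result dict whose "i" key
# carries the checkpoint score (matching checkpoint_score_attr="i" in
# testSameCheckpoint below), and `checkpoint_manager` builds a manager that
# scores by that key.
def mock_result(self, metric, i):
    return {"i": metric, "training_iteration": i}

def checkpoint_manager(self, keep_checkpoints_num):
    return _CheckpointManager(
        keep_checkpoints_num=keep_checkpoints_num,
        checkpoint_score_attr="i",
        delete_fn=lambda c: None,
    )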
def testNewestCheckpoint(self):
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
    memory_checkpoint = _TrackedCheckpoint(
        dir_or_data={"a": 0},
        storage_mode=CheckpointStorage.MEMORY,
        metrics=self.mock_result(0, 0),
    )
    checkpoint_manager.on_checkpoint(memory_checkpoint)
    persistent_checkpoint = _TrackedCheckpoint(
        dir_or_data={"a": 1},
        storage_mode=CheckpointStorage.PERSISTENT,
        metrics=self.mock_result(1, 1),
    )
    checkpoint_manager.on_checkpoint(persistent_checkpoint)
    self.assertEqual(
        checkpoint_manager.newest_persistent_checkpoint, persistent_checkpoint
    )
def testOnCheckpointOrdered(self):
    """
    Tests increasing priorities. Also tests that the worst checkpoints
    are deleted when necessary.
    """
    keep_checkpoints_num = 2
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
    checkpoints = [
        _TrackedCheckpoint(
            dir_or_data={i},
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(i, i),
        )
        for i in range(3)
    ]

    with patch.object(
        checkpoint_manager, "_delete_persisted_checkpoint"
    ) as delete_mock:
        for j in range(3):
            checkpoint_manager.on_checkpoint(checkpoints[j])
            expected_deletes = 0 if j != 2 else 1
            self.assertEqual(delete_mock.call_count, expected_deletes, j)
            self.assertEqual(
                checkpoint_manager.newest_persistent_checkpoint, checkpoints[j]
            )

    best_checkpoints = checkpoint_manager.best_checkpoints()
    self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
    self.assertIn(checkpoints[1], best_checkpoints)
    self.assertIn(checkpoints[2], best_checkpoints)
def test_syncer_callback_noop_on_trial_cloud_checkpointing():
    """Check that a trial using cloud checkpointing disables sync to driver."""
    callbacks = create_default_callbacks(callbacks=[], sync_config=SyncConfig())
    syncer_callback = None
    for cb in callbacks:
        if isinstance(cb, SyncerCallback):
            syncer_callback = cb

    trial1 = MockTrial(trial_id="a", logdir=None)
    trial1.uses_cloud_checkpointing = True

    assert syncer_callback
    assert syncer_callback._enabled

    # Cloud checkpointing is set, so syncing is a no-op.
    assert not syncer_callback._sync_trial_dir(trial1)

    # This should not raise an error for the non-existing directory.
    syncer_callback.on_checkpoint(
        iteration=1,
        trials=[],
        trial=trial1,
        checkpoint=_TrackedCheckpoint(
            dir_or_data="/does/not/exist",
            storage_mode=CheckpointStorage.PERSISTENT,
        ),
    )
def testOnCheckpointUnordered(self):
    """
    Tests priorities that aren't inserted in ascending order. Also tests
    that the worst checkpoints are deleted when necessary.
    """
    keep_checkpoints_num = 2
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
    checkpoints = [
        _TrackedCheckpoint(
            dir_or_data={i},
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(i, i),
        )
        for i in range(3, -1, -1)
    ]

    with patch.object(
        checkpoint_manager, "_delete_persisted_checkpoint"
    ) as delete_mock:
        for j in range(len(checkpoints)):
            checkpoint_manager.on_checkpoint(checkpoints[j])
            expected_deletes = 0 if j != 3 else 1
            self.assertEqual(
                delete_mock.call_count,
                expected_deletes,
                msg=f"Called {delete_mock.call_count} times",
            )
            self.assertEqual(
                checkpoint_manager.newest_persistent_checkpoint, checkpoints[j]
            )

    best_checkpoints = checkpoint_manager.best_checkpoints()
    self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
    self.assertIn(checkpoints[0], best_checkpoints)
    self.assertIn(checkpoints[1], best_checkpoints)
def test_result_grid_future_checkpoint(ray_start_2_cpus, to_object):
    trainable_cls = get_trainable_cls("__fake")
    trial = Trial("__fake", stub=True)
    trial.config = {"some_config": 1}
    trial.last_result = {"some_result": 2, "config": trial.config}

    trainable = ray.remote(trainable_cls).remote()
    ray.get(trainable.set_info.remote({"info": 4}))

    if to_object:
        checkpoint_data = trainable.save_to_object.remote()
    else:
        checkpoint_data = trainable.save.remote()

    trial.on_checkpoint(
        _TrackedCheckpoint(checkpoint_data, storage_mode=CheckpointStorage.MEMORY)
    )
    trial.pickled_error_file = None
    trial.error_file = None
    result_grid = ResultGrid(None)

    # Internal result grid conversion
    result = result_grid._trial_to_result(trial)
    assert isinstance(result.checkpoint, Checkpoint)
    assert isinstance(result.metrics, dict)
    assert isinstance(result.config, dict)
    assert result.metrics_dataframe is None
    assert result.config == {"some_config": 1}
    assert result.metrics["config"] == result.config

    # Load checkpoint data (see ray.rllib.algorithms.mock.MockTrainer definition)
    with result.checkpoint.as_directory() as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "mock_agent.pkl"), "rb") as f:
            info = pickle.load(f)
            assert info["info"] == 4
def _process_checkpoint(
    self,
    checkpoint_results: List[TrainingResult],
    decode_checkpoint_fn: Callable,
) -> None:
    """Ray Train entrypoint. Perform all processing for a checkpoint."""
    # Get checkpoint from first worker.
    checkpoint_data = checkpoint_results[0].data

    # Decode checkpoint.
    checkpoint_data = decode_checkpoint_fn(checkpoint_data)

    score_attr = self._checkpoint_strategy.checkpoint_score_attribute
    if (
        self._checkpoint_strategy.num_to_keep != 0
        and score_attr not in checkpoint_data
    ):
        raise ValueError(
            f"Unable to persist checkpoint for "
            f"checkpoint_score_attribute: {score_attr}. "
            f"Include this attribute in the call to train.save_checkpoint."
        )

    tracked_checkpoint = _TrackedCheckpoint(
        dir_or_data=checkpoint_data,
        checkpoint_id=self._latest_checkpoint_id,
        storage_mode=CheckpointStorage.MEMORY,
        metrics={score_attr: checkpoint_data.get(score_attr, 0.0)},
    )
    self.register_checkpoint(checkpoint=tracked_checkpoint)
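# A hypothetical decode_checkpoint_fn for the entrypoint above (an added
# sketch, not original source). The decoded value must be a dict, since
# _process_checkpoint uses `in` and `.get` on it; this sketch assumes the
# workers ship checkpoints as plain pickled dicts. An identity function
# would also be a valid decoder for already-decoded data.
import pickle

def decode_checkpoint_fn(encoded: bytes) -> dict:
    # Deserialize the raw checkpoint payload into a dict.
    return pickle.loads(encoded)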
def __getstate__(self):
    state = self.__dict__.copy()
    # Avoid serializing the memory checkpoint.
    state["_newest_memory_checkpoint"] = _TrackedCheckpoint(
        dir_or_data=None, storage_mode=CheckpointStorage.MEMORY
    )
    # Avoid serializing lambda since it may capture cyclical dependencies.
    state.pop("_delete_fn")
    return state
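# Since __getstate__ pops "_delete_fn", a matching __setstate__ is implied so
# that unpickled instances still have the attribute. This is a hedged sketch,
# not the original source; it assumes the owner re-registers a real delete
# function after unpickling.
def __setstate__(self, state):
    state["_delete_fn"] = None  # Placeholder until a real delete fn is re-set.
    self.__dict__.update(state)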
def test_unlimited_persistent_checkpoints():
    cpm = _CheckpointManager(
        checkpoint_strategy=CheckpointConfig(num_to_keep=None)
    )

    for i in range(10):
        cpm.register_checkpoint(
            _TrackedCheckpoint(
                {"data": i}, storage_mode=CheckpointStorage.PERSISTENT
            )
        )

    assert len(cpm._top_persisted_checkpoints) == 10
def testSameCheckpoint(self):
    checkpoint_manager = _CheckpointManager(
        keep_checkpoints_num=1,
        checkpoint_score_attr="i",
        delete_fn=lambda c: os.remove(c.dir_or_data),
    )

    tmpfiles = []
    for i in range(3):
        _, tmpfile = tempfile.mkstemp()
        with open(tmpfile, "wt") as fp:
            fp.write("")
        tmpfiles.append(tmpfile)

    checkpoints = [
        _TrackedCheckpoint(
            dir_or_data=tmpfiles[0],
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(5, 5),
        ),
        _TrackedCheckpoint(
            dir_or_data=tmpfiles[1],
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(10, 10),
        ),
        _TrackedCheckpoint(
            dir_or_data=tmpfiles[2],
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(0, 0),
        ),
        _TrackedCheckpoint(
            dir_or_data=tmpfiles[1],
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(20, 20),
        ),
    ]
    for checkpoint in checkpoints:
        checkpoint_manager.on_checkpoint(checkpoint)
        self.assertTrue(os.path.exists(checkpoint.dir_or_data))

    for tmpfile in tmpfiles:
        if os.path.exists(tmpfile):
            os.remove(tmpfile)
def test_persist_memory_checkpoints():
    cpm = _CheckpointManager(
        checkpoint_strategy=CheckpointConfig(num_to_keep=None)
    )
    cpm._persist_memory_checkpoints = True

    for i in range(10):
        cpm.register_checkpoint(
            _TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.MEMORY)
        )

    assert len(cpm._top_persisted_checkpoints) == 10
def testOnMemoryCheckpoint(self):
    checkpoints = [
        _TrackedCheckpoint(
            dir_or_data={"a": 0},
            storage_mode=CheckpointStorage.MEMORY,
            metrics=self.mock_result(0, 0),
        ),
        _TrackedCheckpoint(
            dir_or_data={"a": 0},
            storage_mode=CheckpointStorage.MEMORY,
            metrics=self.mock_result(0, 0),
        ),
    ]
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
    checkpoint_manager.on_checkpoint(checkpoints[0])
    checkpoint_manager.on_checkpoint(checkpoints[1])
    newest = checkpoint_manager.newest_memory_checkpoint

    self.assertEqual(newest, checkpoints[1])
    self.assertEqual(checkpoint_manager.best_checkpoints(), [])
def write_checkpoint(trial: Trial, index: int):
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(trial.logdir, index=index)
    result = {"training_iteration": index}
    with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
        json.dump(result, f)

    tune_cp = _TrackedCheckpoint(
        dir_or_data=checkpoint_dir,
        storage_mode=CheckpointStorage.PERSISTENT,
        metrics=result,
    )
    trial.saving_to = tune_cp

    return checkpoint_dir
def save(
    self,
    trial: Trial,
    storage: CheckpointStorage = CheckpointStorage.PERSISTENT,
    result: Optional[Dict] = None,
) -> _TrackedCheckpoint:
    """Saves the trial's state to a checkpoint asynchronously.

    Args:
        trial: The trial to be saved.
        storage: Where to store the checkpoint. Defaults to PERSISTENT.
        result: The state of this trial as a dictionary to be saved.
            If result is None, the trial's last result will be used.

    Returns:
        Checkpoint object, or None if an Exception occurs.
    """
    logger.debug(f"saving trial {trial}")
    result = result or trial.last_result
    with self._change_working_directory(trial):
        if storage == CheckpointStorage.MEMORY:
            value = trial.runner.save_to_object.remote()
            checkpoint = _TrackedCheckpoint(
                dir_or_data=value, storage_mode=storage, metrics=result
            )
            trial.on_checkpoint(checkpoint)
        else:
            value = trial.runner.save.remote()
            checkpoint = _TrackedCheckpoint(
                dir_or_data=value, storage_mode=storage, metrics=result
            )
            trial.saving_to = checkpoint
            self._futures[value] = (_ExecutorEventType.SAVING_RESULT, trial)
    return checkpoint
def checkpoint(self):
    """Returns the most recent checkpoint.

    If the trial is in ERROR state, the most recent PERSISTENT checkpoint
    is returned.
    """
    if self.status == Trial.ERROR:
        checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
    else:
        checkpoint = self.checkpoint_manager.newest_checkpoint
    if checkpoint.dir_or_data is None:
        checkpoint = _TrackedCheckpoint(
            dir_or_data=self.restore_path,
            storage_mode=CheckpointStorage.PERSISTENT,
        )
    return checkpoint
def testOnCheckpointUnavailableAttribute(self):
    """
    Tests that an error is logged when the associated result of the
    checkpoint has no checkpoint score attribute.
    """
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)

    no_attr_checkpoint = _TrackedCheckpoint(
        dir_or_data=0,
        storage_mode=CheckpointStorage.PERSISTENT,
        metrics={},
    )
    with patch.object(logger, "error") as log_error_mock:
        checkpoint_manager.on_checkpoint(no_attr_checkpoint)
        log_error_mock.assert_called_once()
        # The newest checkpoint should still be set despite this error.
        self.assertEqual(
            checkpoint_manager.newest_persistent_checkpoint, no_attr_checkpoint
        )
def test_keep_best_checkpoints():
    cpm = _CheckpointManager(
        checkpoint_strategy=CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="metric",
            checkpoint_score_order="min",
        )
    )
    cpm._persist_memory_checkpoints = True

    for i in range(10):
        cpm.register_checkpoint(
            _TrackedCheckpoint(
                {"data": i},
                storage_mode=CheckpointStorage.MEMORY,
                metrics={"metric": i},
            )
        )

    # Sorted from worst (max) to best (min)
    assert [
        cp.tracked_checkpoint.metrics["metric"]
        for cp in cpm._top_persisted_checkpoints
    ] == [1, 0]
def test_syncer_callback_force_on_checkpoint(ray_start_2_cpus, temp_data_dirs):
    """Check that on_checkpoint forces syncing."""
    tmp_source, tmp_target = temp_data_dirs

    with freeze_time() as frozen:
        syncer_callback = TestSyncerCallback(
            sync_period=60, local_logdir_override=tmp_target
        )

        trial1 = MockTrial(trial_id="a", logdir=tmp_source)

        syncer_callback.on_trial_result(
            iteration=1, trials=[], trial=trial1, result={}
        )
        syncer_callback.wait_for_all()

        assert_file(True, tmp_target, "level0.txt")
        assert_file(False, tmp_target, "level0_new.txt")

        # Add new file to source directory
        with open(os.path.join(tmp_source, "level0_new.txt"), "w") as f:
            f.write("Data\n")

        assert_file(False, tmp_target, "level0_new.txt")

        frozen.tick(30)

        # Should sync as checkpoint observed
        syncer_callback.on_checkpoint(
            iteration=2,
            trials=[],
            trial=trial1,
            checkpoint=_TrackedCheckpoint(
                dir_or_data=tmp_target,
                storage_mode=CheckpointStorage.PERSISTENT,
            ),
        )
        syncer_callback.wait_for_all()

        assert_file(True, tmp_target, "level0.txt")
        assert_file(True, tmp_target, "level0_new.txt")
def testBestCheckpointsOnlyNan(self):
    """
    Tests that checkpoints with only nan priority are handled correctly.
    """
    keep_checkpoints_num = 2
    checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
    checkpoints = [
        _TrackedCheckpoint(
            dir_or_data=i,
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(float("nan"), i),
        )
        for i in range(4)
    ]

    for checkpoint in checkpoints:
        checkpoint_manager.on_checkpoint(checkpoint)

    best_checkpoints = checkpoint_manager.best_checkpoints()
    # best_checkpoints is sorted from worst to best
    self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
    self.assertEqual(best_checkpoints[0].dir_or_data, 2)
    self.assertEqual(best_checkpoints[1].dir_or_data, 3)
def testBestCheckpoints(self):
    """
    Tests that the best checkpoints are tracked and ordered correctly.
    """
    keep_checkpoints_num = 4
    checkpoints = [
        _TrackedCheckpoint(
            dir_or_data=i,
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=self.mock_result(i, i),
        )
        for i in range(8)
    ]

    for permutation in itertools.permutations(checkpoints):
        checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
        for checkpoint in permutation:
            checkpoint_manager.on_checkpoint(checkpoint)

        best_checkpoints = checkpoint_manager.best_checkpoints()
        self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
        for i in range(len(best_checkpoints)):
            self.assertEqual(best_checkpoints[i].dir_or_data, i + 4)
def checkpoint(self):
    return _TrackedCheckpoint(
        dir_or_data={"data": "None"},
        storage_mode=CheckpointStorage.MEMORY,
        metrics={},
    )
def newest_persistent_checkpoint(self):
    return self._latest_persisted_checkpoint or _TrackedCheckpoint(
        dir_or_data=None,
        checkpoint_id=-1,
        storage_mode=CheckpointStorage.PERSISTENT,
    )
def testCallbackSteps(self):
    trials = [
        Trial("__fake", trial_id="one"),
        Trial("__fake", trial_id="two"),
    ]
    for t in trials:
        self.trial_runner.add_trial(t)

    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.PG_READY
    )
    self.trial_runner.step()

    # Trial 1 has been started
    self.assertEqual(self.callback.state["trial_start"]["iteration"], 0)
    self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id, "one")

    # All these events haven't happened, yet
    self.assertTrue(
        all(
            k not in self.callback.state
            for k in [
                "trial_restore",
                "trial_save",
                "trial_result",
                "trial_complete",
                "trial_fail",
                "experiment_end",
            ]
        )
    )

    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.PG_READY
    )
    self.trial_runner.step()

    # Iteration not increased yet
    self.assertEqual(self.callback.state["step_begin"]["iteration"], 1)
    # Iteration increased
    self.assertEqual(self.callback.state["step_end"]["iteration"], 2)
    # Second trial has just been started
    self.assertEqual(self.callback.state["trial_start"]["iteration"], 1)
    self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id, "two")

    # Just a placeholder object ref for cp.dir_or_data.
    cp = _TrackedCheckpoint(
        dir_or_data=ray.put(1),
        storage_mode=CheckpointStorage.PERSISTENT,
        metrics={TRAINING_ITERATION: 0},
    )
    trials[0].saving_to = cp

    # Let the first trial save a checkpoint
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.SAVING_RESULT,
        trial=trials[0],
        result={_ExecutorEvent.KEY_FUTURE_RESULT: "__checkpoint"},
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
    self.assertEqual(self.callback.state["trial_save"]["trial"].trial_id, "one")

    # Let the second trial send a result
    result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.TRAINING_RESULT,
        trial=trials[1],
        result={_ExecutorEvent.KEY_FUTURE_RESULT: result},
    )
    self.assertTrue(not trials[1].has_reported_at_least_once)
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
    self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id, "two")
    self.assertEqual(self.callback.state["trial_result"]["result"]["metric"], 800)
    self.assertEqual(trials[1].last_result["metric"], 800)

    # Let the second trial restore from a checkpoint
    trials[1].restoring_from = cp
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.RESTORING_RESULT, trial=trials[1]
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
    self.assertEqual(self.callback.state["trial_restore"]["trial"].trial_id, "two")

    # Let the second trial finish
    trials[1].restoring_from = None
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.TRAINING_RESULT,
        trial=trials[1],
        result={
            _ExecutorEvent.KEY_FUTURE_RESULT: {
                TRAINING_ITERATION: 2,
                "metric": 900,
                "done": True,
            }
        },
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
    self.assertEqual(self.callback.state["trial_complete"]["trial"].trial_id, "two")

    # Let the first trial error
    self.executor.next_future_result = _ExecutorEvent(
        event_type=_ExecutorEventType.ERROR,
        trial=trials[0],
        result={_ExecutorEvent.KEY_EXCEPTION: Exception()},
    )
    self.trial_runner.step()
    self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
    self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id, "one")
def newest_memory_checkpoint(self):
    return self._latest_memory_checkpoint or _TrackedCheckpoint(
        dir_or_data=None,
        checkpoint_id=-1,
        storage_mode=CheckpointStorage.MEMORY,
    )
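# Illustration of the fallback behavior in the two properties above (an added
# sketch, not original source): with no checkpoints registered, each property
# returns an empty placeholder rather than None, so callers can simply test
# `checkpoint.dir_or_data is None` instead of guarding against a missing
# value. Assumes the Train-style constructor used in
# test_unlimited_persistent_checkpoints.
def _example_newest_checkpoint_fallback():
    manager = _CheckpointManager(checkpoint_strategy=CheckpointConfig())
    assert manager.newest_memory_checkpoint.dir_or_data is None
    assert manager.newest_persistent_checkpoint.checkpoint_id == -1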