def __init__(
    self,
    keep_checkpoints_num: Optional[int],
    checkpoint_score_attr: Optional[str],
    delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None,
):
    """Build a checkpoint manager from Tune-style checkpoint options.

    The Tune options are translated into a ``CheckpointStrategy``, which is
    handed to the base-class constructor together with ``delete_fn``.

    Args:
        keep_checkpoints_num: Number of checkpoints to keep, forwarded as
            ``num_to_keep``. ``None`` is accepted (see the error message
            below); ``0`` is rejected. Negative values are not checked here.
        checkpoint_score_attr: Name of the metric used to rank checkpoints.
            A ``"min-"`` prefix means lower values are better (``MIN``
            order); otherwise higher is better (``MAX`` order). If falsy,
            ``TRAINING_ITERATION`` is used.
        delete_fn: Optional callback forwarded unchanged to the base class.

    Raises:
        RuntimeError: If ``keep_checkpoints_num`` is ``0``.
    """
    if keep_checkpoints_num == 0:
        raise RuntimeError(
            "If checkpointing is enabled, Ray Tune requires `keep_checkpoints_num` "
            "to be None or a number greater than 0")

    checkpoint_score_attr = checkpoint_score_attr or TRAINING_ITERATION

    # A "min-" prefix selects ascending (MIN) ordering; strip it so that only
    # the bare metric name is passed on as the score attribute.
    min_prefix = "min-"
    minimize = checkpoint_score_attr.startswith(min_prefix)
    if minimize:
        checkpoint_score_attr = checkpoint_score_attr[len(min_prefix):]

    checkpoint_strategy = CheckpointStrategy(
        num_to_keep=keep_checkpoints_num,
        checkpoint_score_attribute=checkpoint_score_attr,
        checkpoint_score_order=MIN if minimize else MAX,
    )
    super().__init__(checkpoint_strategy=checkpoint_strategy, delete_fn=delete_fn)
def test_unlimited_persistent_checkpoints():
    """With ``num_to_keep=None`` all ten registered persistent checkpoints are kept."""
    strategy = CheckpointStrategy(num_to_keep=None)
    manager = _CheckpointManager(checkpoint_strategy=strategy)

    for step in range(10):
        tracked = _TrackedCheckpoint(
            {"data": step}, storage_mode=CheckpointStorage.PERSISTENT
        )
        manager.register_checkpoint(tracked)

    assert len(manager._top_persisted_checkpoints) == 10
def test_persist_memory_checkpoints():
    """Memory checkpoints land in ``_top_persisted_checkpoints`` when
    ``_persist_memory_checkpoints`` is switched on."""
    strategy = CheckpointStrategy(num_to_keep=None)
    manager = _CheckpointManager(checkpoint_strategy=strategy)
    manager._persist_memory_checkpoints = True

    for step in range(10):
        tracked = _TrackedCheckpoint(
            {"data": step}, storage_mode=CheckpointStorage.MEMORY
        )
        manager.register_checkpoint(tracked)

    assert len(manager._top_persisted_checkpoints) == 10
def on_start_training(
    self,
    checkpoint_strategy: Optional[CheckpointStrategy],
    run_dir: Path,
    latest_checkpoint_id: Optional[int] = 0,
):
    """Reset per-run state at the start of training.

    Args:
        checkpoint_strategy: Strategy to use; a falsy value (e.g. ``None``)
            is replaced with a default-constructed ``CheckpointStrategy``.
        run_dir: Directory of the current run, stored on ``self.run_dir``.
        latest_checkpoint_id: ID of the most recent checkpoint; a falsy
            value (e.g. ``None``) is normalised to ``0``.
    """
    if not checkpoint_strategy:
        checkpoint_strategy = CheckpointStrategy()
    self._checkpoint_strategy = checkpoint_strategy

    # Validate before touching any other state, so a bad strategy leaves
    # run_dir / _latest_checkpoint_id untouched.
    self._validate_checkpoint_strategy()

    self.run_dir = run_dir
    self._latest_checkpoint_id = latest_checkpoint_id if latest_checkpoint_id else 0
def test_keep_best_checkpoints():
    """With ``num_to_keep=2`` and ``"min"`` order, only the two lowest-metric
    checkpoints are retained, worst first."""
    strategy = CheckpointStrategy(
        num_to_keep=2,
        checkpoint_score_attribute="metric",
        checkpoint_score_order="min",
    )
    manager = _CheckpointManager(checkpoint_strategy=strategy)
    manager._persist_memory_checkpoints = True

    for step in range(10):
        tracked = _TrackedCheckpoint(
            {"data": step},
            storage_mode=CheckpointStorage.MEMORY,
            metrics={"metric": step},
        )
        manager.register_checkpoint(tracked)

    kept_metrics = [
        entry.tracked_checkpoint.metrics["metric"]
        for entry in manager._top_persisted_checkpoints
    ]
    # Sorted from worst (max) to best (min)
    assert kept_metrics == [1, 0]