示例#1
0
    def __init__(
        self,
        keep_checkpoints_num: int,
        checkpoint_score_attr: Optional[str],
        delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]] = None,
    ):
        """Translate legacy Tune checkpoint args into a ``CheckpointConfig``.

        Args:
            keep_checkpoints_num: Number of checkpoints to keep, or ``None``
                to keep all. ``0`` is invalid and raises.
            checkpoint_score_attr: Metric used to rank checkpoints. A
                ``"min-"`` prefix requests ascending (min-first) order;
                falsy values fall back to ``TRAINING_ITERATION``.
            delete_fn: Optional callback invoked to delete a tracked
                checkpoint; forwarded to the parent class.

        Raises:
            RuntimeError: If ``keep_checkpoints_num`` is 0.
        """
        if keep_checkpoints_num == 0:
            raise RuntimeError(
                "If checkpointing is enabled, Ray Tune requires `keep_checkpoints_num` "
                "to be None or a number greater than 0")

        # Default to ranking checkpoints by training iteration.
        checkpoint_score_attr = checkpoint_score_attr or TRAINING_ITERATION

        # A "min-" prefix means "smaller is better": strip the prefix and
        # remember the direction. (The original code had a dead `else`
        # branch reassigning the attribute to itself; removed.)
        checkpoint_score_desc = checkpoint_score_attr.startswith("min-")
        if checkpoint_score_desc:
            checkpoint_score_attr = checkpoint_score_attr[len("min-"):]

        checkpoint_strategy = CheckpointConfig(
            num_to_keep=keep_checkpoints_num,
            checkpoint_score_attribute=checkpoint_score_attr,
            checkpoint_score_order=MIN if checkpoint_score_desc else MAX,
        )

        super().__init__(checkpoint_strategy=checkpoint_strategy,
                         delete_fn=delete_fn)
示例#2
0
文件: config.py 项目: parasj/ray
    def __post_init__(self):
        """Replace any unset (falsy) sub-config with its default instance."""
        defaults = (
            ("failure_config", FailureConfig),
            ("sync_config", SyncConfig),
            ("checkpoint_config", CheckpointConfig),
        )
        for attr_name, default_factory in defaults:
            if not getattr(self, attr_name):
                setattr(self, attr_name, default_factory())
示例#3
0
def test_unlimited_persistent_checkpoints():
    """With num_to_keep=None, every persistent checkpoint is retained."""
    manager = _CheckpointManager(
        checkpoint_strategy=CheckpointConfig(num_to_keep=None))

    for step in range(10):
        tracked = _TrackedCheckpoint(
            {"data": step}, storage_mode=CheckpointStorage.PERSISTENT)
        manager.register_checkpoint(tracked)

    assert len(manager._top_persisted_checkpoints) == 10
示例#4
0
def test_persist_memory_checkpoints():
    """Memory checkpoints are persisted (and all kept) when the manager
    has _persist_memory_checkpoints enabled and num_to_keep=None."""
    manager = _CheckpointManager(
        checkpoint_strategy=CheckpointConfig(num_to_keep=None))
    manager._persist_memory_checkpoints = True

    for step in range(10):
        tracked = _TrackedCheckpoint(
            {"data": step}, storage_mode=CheckpointStorage.MEMORY)
        manager.register_checkpoint(tracked)

    assert len(manager._top_persisted_checkpoints) == 10
示例#5
0
    def on_start_training(
        self,
        checkpoint_strategy: Optional[CheckpointConfig],
        run_dir: Path,
        latest_checkpoint_id: Optional[int] = 0,
    ):
        """Initialize per-run checkpoint state.

        Stores (and validates) the checkpoint strategy, records the run
        directory, and resets the latest checkpoint id.
        """
        # A falsy strategy (e.g. None) falls back to the default config.
        if not checkpoint_strategy:
            checkpoint_strategy = CheckpointConfig()
        self._checkpoint_strategy = checkpoint_strategy

        self._validate_checkpoint_strategy()

        self.run_dir = run_dir
        # A falsy id (None or 0) means we start counting from 0.
        self._latest_checkpoint_id = (
            latest_checkpoint_id if latest_checkpoint_id else 0
        )
示例#6
0
def test_keep_best_checkpoints():
    """With num_to_keep=2 and "min" order, only the two lowest-metric
    checkpoints survive, ordered from worst (max) to best (min)."""
    manager = _CheckpointManager(
        checkpoint_strategy=CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="metric",
            checkpoint_score_order="min",
        ))
    manager._persist_memory_checkpoints = True

    for value in range(10):
        tracked = _TrackedCheckpoint(
            {"data": value},
            storage_mode=CheckpointStorage.MEMORY,
            metrics={"metric": value},
        )
        manager.register_checkpoint(tracked)

    kept_metrics = [
        entry.tracked_checkpoint.metrics["metric"]
        for entry in manager._top_persisted_checkpoints
    ]
    # Sorted from worst (max) to best (min)
    assert kept_metrics == [1, 0]