Esempio n. 1
0
def test_unlimited_persistent_checkpoints():
    """Check that ``num_to_keep=None`` retains every persistent checkpoint.

    Ten checkpoints stored with ``CheckpointStorage.PERSISTENT`` are
    registered, and all ten are expected in the manager's persisted list.
    """
    strategy = CheckpointConfig(num_to_keep=None)
    manager = _CheckpointManager(checkpoint_strategy=strategy)

    for step in range(10):
        checkpoint = _TrackedCheckpoint(
            {"data": step}, storage_mode=CheckpointStorage.PERSISTENT)
        manager.register_checkpoint(checkpoint)

    assert len(manager._top_persisted_checkpoints) == 10
Esempio n. 2
0
def test_persist_memory_checkpoints():
    """Check that memory checkpoints are persisted when the manager is told to.

    The private ``_persist_memory_checkpoints`` flag is switched on before
    ten ``CheckpointStorage.MEMORY`` checkpoints are registered. With no
    ``num_to_keep`` limit, all ten are expected in the persisted list.
    """
    strategy = CheckpointConfig(num_to_keep=None)
    manager = _CheckpointManager(checkpoint_strategy=strategy)
    # Opt in to persisting in-memory checkpoints; the test depends on it.
    manager._persist_memory_checkpoints = True

    for step in range(10):
        checkpoint = _TrackedCheckpoint(
            {"data": step}, storage_mode=CheckpointStorage.MEMORY)
        manager.register_checkpoint(checkpoint)

    assert len(manager._top_persisted_checkpoints) == 10
Esempio n. 3
0
def test_keep_best_checkpoints():
    """Check that only the best-scoring checkpoints survive, ordered worst to best.

    The manager keeps two checkpoints, ranked by the ``"metric"`` value
    where lower is better. After registering metrics 0 through 9, the
    survivors are expected to be metrics 1 and 0, in that order.
    """
    strategy = CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="metric",
        checkpoint_score_order="min",
    )
    manager = _CheckpointManager(checkpoint_strategy=strategy)
    manager._persist_memory_checkpoints = True

    for step in range(10):
        checkpoint = _TrackedCheckpoint(
            {"data": step},
            storage_mode=CheckpointStorage.MEMORY,
            metrics={"metric": step},
        )
        manager.register_checkpoint(checkpoint)

    kept_metrics = [
        persisted.tracked_checkpoint.metrics["metric"]
        for persisted in manager._top_persisted_checkpoints
    ]
    # The list runs from the worst kept score (higher) to the best (lower).
    assert kept_metrics == [1, 0]