Ejemplo n.º 1
0
def test_checkpoint_recorder_to_json(loss_labels):
    """Test that JSON output is of the expected format with correct data.

    Populates the recording's per-step results directly (bypassing
    ``record_metric``) and checks that every step in the JSON ``results``
    section matches the stored result for that step.
    """
    start_step = 0
    max_len = 5
    cr = CheckpointRecording(start_step, f"mock/ckpt_{start_step}/path",
                             max_len)

    # Mock ``_results`` (the per-step store also mocked in
    # test_log_recorder_to_json). Ad-hoc ``losses``/``accuracies`` attributes
    # are presumably not read by ``to_json``, which would leave ``results``
    # empty and make the comparison loop below vacuous.
    cr._results = {
        s: random_result(loss_labels)
        for s in range(cr.start_step, cr.start_step + cr.max_len)
    }

    cj = cr.to_json()

    # Guard against a vacuous pass: every mocked step must be serialised.
    assert len(cj["results"]) == len(cr._results)
    for step, result in cj["results"].items():
        assert result == cr._results[step]
Ejemplo n.º 2
0
def test_checkpoint_recorder_can_add_step():
    """Ensure ``can_add_step`` respects the maximum number of recorded steps.

    A step can be added while fewer than ``max_len`` distinct steps have been
    recorded, and an already-recorded step is still accepted; once the
    recording holds ``max_len`` steps, new steps are rejected.
    """
    max_len = 10
    cr = CheckpointRecording(10, "mock/ckpt_10/path", max_len)

    # Empty recording: the first step is accepted.
    assert cr.can_add_step(0) is True
    cr._recorded_steps.add(0)
    # Revisiting a recorded step and adding a new one are both still allowed.
    assert cr.can_add_step(0) is True
    assert cr.can_add_step(1) is True

    # Fill the recording with max_len distinct steps: 0 .. max_len - 1.
    last_step = max_len - 1
    for step in range(1, max_len):
        cr._recorded_steps.add(step)

    assert cr.can_add_step(last_step) is True
    assert cr.can_add_step(last_step + 1) is False
    assert cr.can_add_step(last_step + 2) is False
Ejemplo n.º 3
0
def test_checkpoint_recorder_record(max_len, num_losses):
    """Ensure ``record_metric`` stores data and respects the max log length.

    Records a "loss" and an "accuracy" metric for each of ``max_len``
    consecutive steps and checks that ``current_len`` equals ``max_len``.
    It then checks that recording one step beyond the limit raises
    ``RecorderIsFullException``.
    """
    start_step = 0
    cr = CheckpointRecording(start_step, f"mock/ckpt_{start_step}/path",
                             max_len)

    step_losses = np.random.random(size=(max_len, num_losses))
    step_accuracies = np.random.random(size=(max_len, num_losses))

    for i in range(max_len):
        step = start_step + i
        losses = {j: step_losses[i, j] for j in range(num_losses)}
        accuracies = {j: step_accuracies[i, j] for j in range(num_losses)}
        print(f"Recording step: {step}")
        print(f"Current length: {cr.current_len}")
        cr.record_metric(step, "loss", losses)
        cr.record_metric(step, "accuracy", accuracies)

    assert cr.current_len == max_len

    # Pass a real (string) metric name, as in the calls above, so the test
    # exercises only the fullness check. An unhashable ``{}`` name would make
    # the outcome depend on how the name is handled.
    with pytest.raises(RecorderIsFullException):
        cr.record_metric(start_step + max_len, "loss", {})
Ejemplo n.º 4
0
def test_log_recorder_to_json(num_checkpoints, recording_length, loss_labels):
    """Test that JSON output is of the expected format with correct data.

    Builds ``num_checkpoints`` mocked ``CheckpointRecording`` objects with
    random per-step results and registers them on a ``LogRecorder``. It then
    checks that ``LogRecorder.to_json`` reproduces each checkpoint's path,
    start step and per-step losses/accuracies.
    """

    rec = LogRecorder(recording_length, "/mock/recording/path",
                      "/mock/recording/path/ckpt")

    # Space start steps two recording lengths apart so checkpoints never overlap.
    start_steps = [i * 2 * recording_length for i in range(num_checkpoints)]
    ckpt_recordings = [
        CheckpointRecording(s, f"/mock/ckpt_{i}/path", recording_length)
        for i, s in enumerate(start_steps)
    ]

    for cr in ckpt_recordings:
        cr._results = {
            s: random_result(loss_labels)
            for s in range(cr.start_step, cr.start_step + cr.max_len)
        }
        rec.checkpoint_logs[cr.start_step] = cr

    # Sanity check the mocking code above
    assert len(rec.checkpoint_logs) == num_checkpoints

    json_output = rec.to_json()
    assert len(json_output["ckpt_logs"]) == num_checkpoints

    # Relies on ``ckpt_logs`` preserving insertion order, so the i-th key
    # corresponds to the i-th mocked checkpoint.
    for i, log_key in enumerate(json_output["ckpt_logs"]):
        cj = json_output["ckpt_logs"][log_key]
        cr = rec.checkpoint_logs[log_key]
        assert cj["checkpoint"] == f"/mock/ckpt_{i}/path"
        assert cj["start_step"] == start_steps[i]

        assert len(cj["results"]) == len(cr._results)

        for step, result in cj["results"].items():
            assert result["losses"] == cr._results[step]["losses"]
            assert result["accuracies"] == cr._results[step]["accuracies"]
Ejemplo n.º 5
0
def test_checkpoint_recorder_record_metric_current_len():
    """The length of the recording should only change when a new step is
    added, not when a new metric is added to an already-recorded step."""
    max_len = 10
    cr = CheckpointRecording(10, "mock/ckpt_10/path", max_len)
    assert cr.current_len == 0

    # (step, metric name, expected current_len after recording it)
    cases = [
        (0, "metric_0", 1),
        (0, "metric_1", 1),
        (0, "metric_2", 1),
        (1, "metric_0", 2),
        (1, "metric_1", 2),
        (2, "metric_2", 3),
    ]
    for step, metric, expected_len in cases:
        cr.record_metric(step, metric, random.random())
        # Include the case in the failure message to pinpoint which call broke.
        assert cr.current_len == expected_len, (step, metric)