示例#1
0
def test_checkpoint_recorder_to_json(loss_labels):
    """Test that JSON output is of the expected format with correct data.

    Fills a CheckpointRecording with random losses/accuracies for every step
    it covers, serialises it with ``to_json`` and checks that each serialised
    per-step result matches the recording's ``results``.
    """
    start_step = 0
    max_len = 5
    cr = CheckpointRecording(start_step, f"mock/ckpt_{start_step}/path",
                             max_len)

    # One entry per step the recording covers: [start_step, start_step + max_len).
    steps = range(cr.start_step, cr.start_step + cr.max_len)
    # NOTE(review): if CheckpointRecording keeps its data somewhere other than
    # `losses` / `accuracies` (e.g. a `_results` dict), this mocking is stale;
    # the non-empty check below is what will surface that.
    cr.losses = {s: random_result(loss_labels) for s in steps}
    cr.accuracies = {s: random_result(loss_labels) for s in steps}

    cj = cr.to_json()
    results = cj["results"]

    # Without these guards an empty or truncated "results" mapping would make
    # the comparison loop below assert nothing and the test pass vacuously.
    assert results, "to_json() produced no per-step results"
    assert len(results) == len(cr.results)

    for step, serialised in results.items():
        assert serialised == cr.results[step]
示例#2
0
def test_log_recorder_to_json(num_checkpoints, recording_length, loss_labels):
    """Test that JSON output is of the expected format with correct data.

    Builds a LogRecorder holding ``num_checkpoints`` CheckpointRecordings,
    each pre-filled with random per-step results, serialises it with
    ``to_json`` and checks every checkpoint's path, start step and per-step
    losses/accuracies against the source recordings.
    """

    rec = LogRecorder(recording_length, "/mock/recording/path",
                      "/mock/recording/path/ckpt")

    # Start steps are 2 * recording_length apart, so the step ranges
    # [start, start + recording_length) of the recordings never overlap.
    start_steps = [i * 2 * recording_length for i in range(num_checkpoints)]
    ckpt_recordings = [
        CheckpointRecording(s, f"/mock/ckpt_{i}/path", recording_length)
        for i, s in enumerate(start_steps)
    ]

    for cr in ckpt_recordings:
        # Inject results directly so to_json is tested in isolation from the
        # step-by-step recording API.
        cr._results = {
            s: random_result(loss_labels)
            for s in range(cr.start_step, cr.start_step + cr.max_len)
        }
        rec.checkpoint_logs[cr.start_step] = cr

    # Sanity check the mocking code above
    assert len(rec.checkpoint_logs) == num_checkpoints

    json_output = rec.to_json()
    assert len(json_output["ckpt_logs"]) == num_checkpoints

    # The positional comparisons (start_steps[i], the i-th mock path) assume
    # "ckpt_logs" preserves insertion order, i.e. ascending start step.
    for i, (log_key, cj) in enumerate(json_output["ckpt_logs"].items()):
        cr = rec.checkpoint_logs[log_key]
        assert cj["checkpoint"] == f"/mock/ckpt_{i}/path"
        assert cj["start_step"] == start_steps[i]

        assert len(cj["results"]) == len(cr._results)

        for step, result in cj["results"].items():
            assert result["losses"] == cr._results[step]["losses"]
            assert result["accuracies"] == cr._results[step]["accuracies"]
示例#3
0
def test_log_recorder_updates(caplog):
    """Ensure correct update behaviour:
        1) If we try to record without having initialised a post-checkpoint recording - warn and ignore
        2) If we have a recorder set, call its record method (internals tested in a separate test - this should be mocked)
        3) If the recorder is full, raise a relevant exception.
    """
    MAX_LEN = 5

    def check_rec(rec, e_step, e_loss, e_acc):
        """Assert the active recorder saw step `e_step` with the given metrics."""
        assert rec.is_recording()
        assert rec._current_recorder.last_step == e_step
        assert rec._current_recorder.metrics["loss"] == e_loss
        assert rec._current_recorder.metrics["accuracy"] == e_acc

    # Stub out the stop and save methods -> we're not testing their behaviour here, just
    # need to make sure they're called correctly for (3)
    mock_status = {"recording_stopped": False, "recording_saved": False}

    def mock_stop_recording():
        mock_status["recording_stopped"] = True

    def mock_save_recording():
        mock_status["recording_saved"] = True

    # Start test
    rec = LogRecorder(MAX_LEN, "/mock/recording/path",
                      "/mock/recording/path/ckpt")
    rec.update_step(0)

    rec.stop_recording = mock_stop_recording
    rec.save = mock_save_recording

    # (1) No recorder attached yet: the metric must be skipped with a log message.
    with caplog.at_level(logging.INFO, logger='convergence-harness'):
        rec.record_step_metric("some_metric", {})
        expected_warning = "Trying to record step 0, but recorder is None. Skipping entry."
        assert caplog.records[0].message == expected_warning

    mock_loss = random_result()
    mock_accuracy = random_result()

    cr = MockCheckpointRecording(0)
    rec._current_recorder = cr

    # (2) With a recorder attached, metrics are forwarded for the current step.
    rec.update_step(1)
    assert rec.current_step == 1
    rec.record_step_metric("loss", mock_loss)
    rec.record_step_metric("accuracy", mock_accuracy)
    check_rec(rec, 1, mock_loss, mock_accuracy)

    rec.update_step(5)
    assert rec.current_step == 5

    mock_loss = random_result()
    mock_accuracy = random_result()

    rec.record_step_metric("loss", mock_loss)
    rec.record_step_metric("accuracy", mock_accuracy)
    check_rec(rec, 5, mock_loss, mock_accuracy)

    assert mock_status["recording_stopped"] is False
    assert mock_status["recording_saved"] is False

    # (3) Full recorder: every record attempt must raise. Each call gets its
    # own pytest.raises block. In a single block, the first raising call
    # would skip the remaining statements, so the "accuracy" call would
    # never run.
    rec._current_recorder._can_add = False
    with pytest.raises(RecorderIsFullException):
        rec.record_step_metric("loss", mock_loss)
    with pytest.raises(RecorderIsFullException):
        rec.record_step_metric("accuracy", mock_accuracy)
    # NOTE(review): the stop/save stubs above are installed "for (3)", but
    # mock_status is never checked after this point. Confirm the expected
    # values and assert them here.