Example No. 1
0
def test_log_recorder_is_recording():
    """The is_recording flag must mirror whether a checkpoint recorder is active."""
    recorder = LogRecorder(4, "/mock/recording/path", "/mock/recording/path/ckpt")

    # Freshly constructed: no recorder attached, so not recording.
    assert recorder._current_recorder is None
    assert recorder.is_recording() is False

    # Starting a recording attaches a recorder and flips the flag on.
    recorder.start_recording("/mock/ckpt/path")
    assert recorder._current_recorder is not None
    assert recorder.is_recording() is True

    # Stopping detaches the recorder and flips the flag back off.
    recorder.stop_recording()
    assert recorder._current_recorder is None
    assert recorder.is_recording() is False
Example No. 2
0
def create_gather_process(output_path, manifest_name, full_config, process_cmd,
                          ckpt_root):
    """Launch the training process and feed its log lines into a LogRecorder.

    Args:
        output_path: Directory in which the recording manifest lives.
        manifest_name: Filename of the recording manifest.
        full_config: Harness config dict; keys used here: "recording",
            "log_output", "log_parsing", "metrics".
        process_cmd: Command (list/str) handed to subprocess.Popen.
        ckpt_root: Root directory for checkpoint storage.

    The recorder is always stopped, saved, and the child killed on exit,
    even if log handling raises.
    """
    # Disable output buffering to make sure we get the process logs as they're recorded.
    os.environ['PYTHONUNBUFFERED'] = "1"
    proc = subprocess.Popen(
        process_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )

    log_recorder = LogRecorder(
        full_config["recording"]["steps_to_record_after_save"],
        manifest_path(output_path, manifest_name), ckpt_root)

    # The monitored stream never changes mid-run, so pick it once instead of
    # re-testing the config on every line.
    # NOTE(review): with both stdout and stderr piped, heavy output on the
    # UNmonitored stream can fill its OS pipe buffer and stall the child —
    # confirm the child is quiet on that stream.
    if full_config["log_output"] == "stderr":
        stream = proc.stderr
    else:
        stream = proc.stdout

    try:
        while proc.poll() is None:
            line = stream.readline()
            gather_log_handler(line, full_config["log_parsing"],
                               full_config["metrics"], log_recorder)
        # Once poll() reports exit there may still be buffered lines written
        # just before the child terminated; drain them so no logs are lost.
        for line in stream:
            gather_log_handler(line, full_config["log_parsing"],
                               full_config["metrics"], log_recorder)
    finally:
        log_recorder.stop_recording()
        log_recorder.save()
        proc.kill()
Example No. 3
0
def test_log_recorder_cannot_add_should_stop(monkeypatch):
    """
    If the checkpoint recorder cannot add another step (due to being full),
    the log recorder should stop and save the checkpoint recorder
    """
    rec = LogRecorder(1, "/mock/recording/path", "/mock/recording/path/ckpt")

    # Count calls to stop_recording/save rather than running their real bodies.
    calls = {"stopped": 0, "saved": 0}

    def fake_stop():
        calls["stopped"] += 1

    def fake_save():
        calls["saved"] += 1

    monkeypatch.setattr(rec, "stop_recording", fake_stop)
    monkeypatch.setattr(rec, "save", fake_save)

    rec.start_recording("/mock/ckpt/path")
    rec.update_step(0)

    # Recorder has room for step 0; nothing stopped or saved yet.
    assert rec._current_recorder.can_add_step(0) is True
    assert calls["stopped"] == 0
    assert calls["saved"] == 0

    # Recording a metric fills the (length-1) recorder: step 0 is still
    # addable, but step 1 is not.
    rec.record_step_metric("metric", {"value": 0})
    assert rec._current_recorder.can_add_step(0) is True
    assert rec._current_recorder.can_add_step(1) is False
    assert calls["stopped"] == 0
    assert calls["saved"] == 0

    # Advancing to the un-addable step must trigger exactly one stop + save.
    rec.update_step(1)
    assert calls["stopped"] == 1
    assert calls["saved"] == 1
Example No. 4
0
def test_log_recorder_to_json(num_checkpoints, recording_length, loss_labels):
    """Test that JSON output is of the expected format with correct data"""

    rec = LogRecorder(recording_length, "/mock/recording/path",
                      "/mock/recording/path/ckpt")

    # Fabricate `num_checkpoints` recordings, spaced two recording-lengths apart,
    # each pre-populated with random per-step results.
    start_steps = [i * 2 * recording_length for i in range(num_checkpoints)]
    for idx, start in enumerate(start_steps):
        recording = CheckpointRecording(start, f"/mock/ckpt_{idx}/path",
                                        recording_length)
        recording._results = {
            step: random_result(loss_labels)
            for step in range(recording.start_step,
                              recording.start_step + recording.max_len)
        }
        rec.checkpoint_logs[recording.start_step] = recording

    # Sanity check the mocking code above
    assert len(rec.checkpoint_logs) == num_checkpoints

    json_output = rec.to_json()
    assert len(json_output["ckpt_logs"]) == num_checkpoints

    # Each serialized entry must round-trip the recording it came from.
    for idx, key in enumerate(json_output["ckpt_logs"]):
        entry = json_output["ckpt_logs"][key]
        recording = rec.checkpoint_logs[key]

        assert entry["checkpoint"] == f"/mock/ckpt_{idx}/path"
        assert entry["start_step"] == start_steps[idx]
        assert len(entry["results"]) == len(recording._results)

        for step, result in entry["results"].items():
            assert result["losses"] == recording._results[step]["losses"]
            assert result["accuracies"] == recording._results[step]["accuracies"]
Example No. 5
0
def test_log_recorder_updates(caplog):
    """Ensure correct update behaviour:
        1) If we try to record without having initialised a post-checkpoint recording - warn and ignore
        2) If we have a recorder set, call its record method (internals tested in a separate test - this should be mocked)
        3) If the recorder is full, raise a relevant exception.
    """
    MAX_LEN = 5

    def check_rec(rec, e_step, e_loss, e_acc):
        # Helper: recorder is active and its latest step/metrics match expectations.
        assert rec.is_recording()
        assert rec._current_recorder.last_step == e_step
        assert rec._current_recorder.metrics["loss"] == e_loss
        assert rec._current_recorder.metrics["accuracy"] == e_acc

    # Stub out the stop and save methods -> we're not testing their behaviour here, just
    # need to make sure they're called correctly for (3)
    mock_status = {"recording_stopped": False, "recording_saved": False}

    def mock_stop_recording():
        mock_status["recording_stopped"] = True

    def mock_save_recording():
        mock_status["recording_saved"] = True

    # Start test
    rec = LogRecorder(MAX_LEN, "/mock/recording/path",
                      "/mock/recording/path/ckpt")
    rec.update_step(0)

    rec.stop_recording = mock_stop_recording
    rec.save = mock_save_recording

    # (1) No recorder attached yet: recording must be skipped with a warning.
    with caplog.at_level(logging.INFO, logger='convergence-harness'):
        rec.record_step_metric("some_metric", {})
        expected_warning = "Trying to record step 0, but recorder is None. Skipping entry."
        assert caplog.records[0].message == expected_warning

    mock_loss = random_result()
    mock_accuracy = random_result()

    cr = MockCheckpointRecording(0)
    rec._current_recorder = cr

    # (2) With a recorder attached, metrics are forwarded to it.
    rec.update_step(1)
    assert rec.current_step == 1
    rec.record_step_metric("loss", mock_loss)
    rec.record_step_metric("accuracy", mock_accuracy)
    check_rec(rec, 1, mock_loss, mock_accuracy)

    rec.update_step(5)
    assert rec.current_step == 5

    mock_loss = random_result()
    mock_accuracy = random_result()

    rec.record_step_metric("loss", mock_loss)
    rec.record_step_metric("accuracy", mock_accuracy)
    check_rec(rec, 5, mock_loss, mock_accuracy)

    assert mock_status["recording_stopped"] is False
    assert mock_status["recording_saved"] is False

    # (3) A full recorder must raise. NOTE: the original test had a second
    # record_step_metric call inside this block; it was unreachable because
    # the first call raises, so it has been removed.
    rec._current_recorder._can_add = False
    with pytest.raises(RecorderIsFullException):
        rec.record_step_metric("loss", mock_loss)
Example No. 6
0
def test_log_recorder_start_recording(caplog):
    """Correct behaviour for starting a recording:
        1) Creating the checkpoint recorder correctly
        2) Stopping stores the recorder and clears ready for the next start
        3) If a run is still recording, stop the previous one and start anew"""

    MAX_LEN = 5

    def assert_active_recorder(recorder, step, rel_path, max_length):
        # Helper: an active checkpoint recorder matches the expected
        # start step, relative checkpoint path, and capacity.
        assert recorder.is_recording()
        assert recorder._current_recorder.start_step == step
        assert recorder._current_recorder.checkpoint_path == rel_path
        assert recorder._current_recorder.max_len == max_length

    storage_path = "/mock/recording/path/ckpt"
    rel_ckpt_paths = [f"mock_ckpt_{c}/path" for c in range(3)]
    full_ckpt_paths = [os.path.join(storage_path, p) for p in rel_ckpt_paths]

    step = 5
    rec = LogRecorder(MAX_LEN, "/mock/recording/path",
                      "/mock/recording/path/ckpt")
    rec.update_step(step)

    # (1) Starting creates a recorder keyed to the current step, with the
    # checkpoint path stored relative to the storage root.
    rec.start_recording(full_ckpt_paths[0])
    assert_active_recorder(rec, step, rel_ckpt_paths[0], MAX_LEN)

    # (2) Stopping archives the recorder under its start step and clears it.
    rec.stop_recording()
    assert rec._current_recorder is None
    assert step in rec.checkpoint_logs

    step = 10
    rec.update_step(step)
    rec.start_recording(full_ckpt_paths[1])
    assert_active_recorder(rec, step, rel_ckpt_paths[1], MAX_LEN)

    previous_step = step
    step = 20
    rec.update_step(step)

    # (3) Starting while already recording warns, archives the old recorder,
    # and begins a fresh one at the current step.
    with caplog.at_level(logging.INFO, logger='convergence-harness'):
        rec.start_recording(full_ckpt_paths[2])
        expected_warning = "Already recording logs for the previous checkpoint. Stopping here and starting a fresh log."
        assert caplog.records[0].message == expected_warning
    assert previous_step in rec.checkpoint_logs
    assert step not in rec.checkpoint_logs
    assert_active_recorder(rec, step, rel_ckpt_paths[2], MAX_LEN)

    rec.stop_recording()
    assert rec._current_recorder is None
    assert step in rec.checkpoint_logs