import logging

import pytest

# CheckpointRecording, LogRecorder, RecorderIsFullException and the test helpers
# (random_result, MockCheckpointRecording) are assumed to be provided by the project
# under test and its test utilities / conftest.py.


def test_checkpoint_recorder_to_json(loss_labels):
    """Test that JSON output is of the expected format with correct data"""
    start_step = 0
    max_len = 5
    cr = CheckpointRecording(start_step, f"mock/ckpt_{start_step}/path", max_len)

    cr.losses = {
        s: random_result(loss_labels)
        for s in range(cr.start_step, cr.start_step + cr.max_len)
    }
    cr.accuracies = {
        s: random_result(loss_labels)
        for s in range(cr.start_step, cr.start_step + cr.max_len)
    }

    cj = cr.to_json()
    for k, v in cj["results"].items():
        assert v == cr.results[k]
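# Illustration only (not used by the tests): a hypothetical helper equivalent to the
# random_result() used above, showing the shape of data these tests appear to assume -
# per-label values grouped under "losses" and "accuracies" (see the assertions in
# test_log_recorder_to_json). The real helper lives in the project's test utilities and
# may differ; the name, default labels and value ranges below are placeholders.
import random


def _example_random_result(labels=("some_label",)):
    """Build a result-like dict with one random value per label (assumed shape)."""
    return {
        "losses": {label: random.random() for label in labels},
        "accuracies": {label: random.random() for label in labels},
    }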
def test_log_recorder_to_json(num_checkpoints, recording_length, loss_labels):
    """Test that JSON output is of the expected format with correct data"""
    rec = LogRecorder(recording_length, "/mock/recording/path", "/mock/recording/path/ckpt")

    start_steps = [i * 2 * recording_length for i in range(num_checkpoints)]
    ckpt_recordings = [
        CheckpointRecording(s, f"/mock/ckpt_{i}/path", recording_length)
        for i, s in enumerate(start_steps)
    ]

    for cr in ckpt_recordings:
        cr._results = {
            s: random_result(loss_labels)
            for s in range(cr.start_step, cr.start_step + cr.max_len)
        }
        rec.checkpoint_logs[cr.start_step] = cr

    # Sanity check the mocking code above
    assert len(rec.checkpoint_logs) == num_checkpoints

    json_output = rec.to_json()
    assert len(json_output["ckpt_logs"]) == num_checkpoints

    for i, log_key in enumerate(json_output["ckpt_logs"]):
        cj = json_output["ckpt_logs"][log_key]
        cr = rec.checkpoint_logs[log_key]
        assert cj["checkpoint"] == f"/mock/ckpt_{i}/path"
        assert cj["start_step"] == start_steps[i]
        assert len(cj["results"]) == len(cr._results)
        for step, result in cj["results"].items():
            assert result["losses"] == cr._results[step]["losses"]
            assert result["accuracies"] == cr._results[step]["accuracies"]
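# For reference only (not used by the tests): the to_json() structure that
# test_log_recorder_to_json implicitly expects, reconstructed from its assertions. The key
# names come from the test; the concrete values are illustrative placeholders.
_EXAMPLE_LOG_RECORDER_JSON = {
    "ckpt_logs": {
        0: {  # keyed by the checkpoint recording's start step
            "checkpoint": "/mock/ckpt_0/path",
            "start_step": 0,
            "results": {
                0: {"losses": {"some_label": 0.1}, "accuracies": {"some_label": 0.9}},
                # ... one entry per step in [start_step, start_step + max_len)
            },
        },
    },
}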
def test_log_recorder_updates(caplog):
    """Ensure correct update behaviour:
    1) If we try to record without having initialised a post-checkpoint recording
       - warn and ignore
    2) If we have a recorder set, call its record method (internals tested in a
       separate test - this should be mocked)
    3) If the recorder is full, raise a relevant exception.
    """
    MAX_LEN = 5

    def check_rec(rec, e_step, e_loss, e_acc):
        assert rec.is_recording()
        assert rec._current_recorder.last_step == e_step
        assert rec._current_recorder.metrics["loss"] == e_loss
        assert rec._current_recorder.metrics["accuracy"] == e_acc

    # Stub out the stop and save methods -> we're not testing their behaviour here, just
    # need to make sure they're called correctly for (3)
    mock_status = {"recording_stopped": False, "recording_saved": False}

    def mock_stop_recording():
        mock_status["recording_stopped"] = True

    def mock_save_recording():
        mock_status["recording_saved"] = True

    # Start test
    rec = LogRecorder(MAX_LEN, "/mock/recording/path", "/mock/recording/path/ckpt")
    rec.update_step(0)
    rec.stop_recording = mock_stop_recording
    rec.save = mock_save_recording

    # (1)
    with caplog.at_level(logging.INFO, logger='convergence-harness'):
        rec.record_step_metric("some_metric", {})
    expected_warning = "Trying to record step 0, but recorder is None. Skipping entry."
    assert caplog.records[0].message == expected_warning

    mock_loss = random_result()
    mock_accuracy = random_result()
    cr = MockCheckpointRecording(0)
    rec._current_recorder = cr

    # (2)
    rec.update_step(1)
    assert rec.current_step == 1
    rec.record_step_metric("loss", mock_loss)
    rec.record_step_metric("accuracy", mock_accuracy)
    check_rec(rec, 1, mock_loss, mock_accuracy)

    rec.update_step(5)
    assert rec.current_step == 5
    mock_loss = random_result()
    mock_accuracy = random_result()
    rec.record_step_metric("loss", mock_loss)
    rec.record_step_metric("accuracy", mock_accuracy)
    check_rec(rec, 5, mock_loss, mock_accuracy)

    assert mock_status["recording_stopped"] is False
    assert mock_status["recording_saved"] is False

    # (3)
    rec._current_recorder._can_add = False
    with pytest.raises(RecorderIsFullException):
        rec.record_step_metric("loss", mock_loss)
        rec.record_step_metric("accuracy", mock_accuracy)
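# The tests above take loss_labels, num_checkpoints and recording_length as pytest fixtures,
# which are expected to come from the project's conftest.py. A hypothetical minimal version
# is sketched below for context only (kept as a comment so it does not shadow the real
# fixtures); the fixture names match the tests, but the parametrisation and label values are
# guesses:
#
#   @pytest.fixture(params=[1, 3])
#   def num_checkpoints(request):
#       return request.param
#
#   @pytest.fixture(params=[2, 5])
#   def recording_length(request):
#       return request.param
#
#   @pytest.fixture
#   def loss_labels():
#       return ["some_loss_label", "another_loss_label"]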