def test_get_checkpoints_from_model_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Check that CheckpointHandler.get_checkpoints_from_model returns exactly the checkpoint
    files listed in the model inference config found under FINAL_ENSEMBLE_MODEL_FOLDER.

    :param test_output_dirs: Test fixture whose root_dir is used as the download location.
    """
    model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    workspace = get_default_workspace()
    returned_checkpoints = CheckpointHandler.get_checkpoints_from_model(
        model_id=model_id,
        workspace=workspace,
        download_path=test_output_dirs.root_dir,
    )

    # The model folder must exist before its inference config can be read.
    model_root = test_output_dirs.root_dir / FINAL_ENSEMBLE_MODEL_FOLDER
    assert model_root.is_dir()

    # Paths in the inference config are relative to the model folder.
    inference_config = read_model_inference_config(model_root / MODEL_INFERENCE_JSON_FILE_NAME)
    paths_from_config = [model_root / relative_path for relative_path in inference_config.checkpoint_paths]

    # Same count and same members: every checkpoint in the config was returned, and nothing else.
    assert len(paths_from_config) == len(returned_checkpoints)
    assert set(paths_from_config) == set(returned_checkpoints)
    # Every checkpoint named in the config must be present on disk.
    assert all(checkpoint.is_file() for checkpoint in paths_from_config)
def test_get_checkpoints_from_model_single_run(test_output_dirs: OutputFolderForTests) -> None:
    """
    Check that CheckpointHandler.get_checkpoints_from_model returns exactly one checkpoint,
    matching the single entry of the model inference config found under FINAL_MODEL_FOLDER.

    :param test_output_dirs: Test fixture whose root_dir is used as the download location.
    """
    model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    workspace = get_default_workspace()
    returned_checkpoints = CheckpointHandler.get_checkpoints_from_model(
        model_id=model_id,
        workspace=workspace,
        download_path=test_output_dirs.root_dir,
    )

    # The model folder must exist before its inference config can be read.
    model_root = test_output_dirs.root_dir / FINAL_MODEL_FOLDER
    assert model_root.is_dir()

    # Paths in the inference config are relative to the model folder.
    inference_config = read_model_inference_config(model_root / MODEL_INFERENCE_JSON_FILE_NAME)
    paths_from_config = [model_root / relative_path for relative_path in inference_config.checkpoint_paths]

    # A registered model for a non-ensemble run should contain only one checkpoint
    assert len(paths_from_config) == 1
    assert len(returned_checkpoints) == 1

    # The single configured checkpoint is the one that was returned, and it is present on disk.
    only_checkpoint = paths_from_config[0]
    assert only_checkpoint == returned_checkpoints[0]
    assert only_checkpoint.is_file()