def test_download_checkpoints(test_output_dirs: OutputFolderForTests, is_ensemble: bool,
                              runner_config: AzureConfig) -> None:
    """
    Download the checkpoints of a known AzureML run and check that RunRecovery reports, and writes to disk,
    exactly the expected epoch-1 checkpoint files.

    :param test_output_dirs: Fixture providing a fresh output folder for this test.
    :param is_ensemble: If True, recover the ensemble run (DEFAULT_ENSEMBLE_RUN_RECOVERY_ID), whose child runs
        each contribute one checkpoint root. If False, recover the single run (DEFAULT_RUN_RECOVERY_ID).
    :param runner_config: Azure configuration; its run_recovery_id is overwritten by this test.
    """
    output_dir = test_output_dirs.root_dir
    assert get_results_blob_path("some_run_id") == "azureml/ExperimentRun/dcid.some_run_id"
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(output_dir)
    # Any recent run ID from a PR build will do. Use a PR build because the checkpoint files are small there.
    runner_config.run_recovery_id = DEFAULT_ENSEMBLE_RUN_RECOVERY_ID if is_ensemble else DEFAULT_RUN_RECOVERY_ID
    run_recovery = RunRecovery.download_checkpoints_from_recovery_run(runner_config, config)
    run_to_recover = fetch_run(workspace=runner_config.get_workspace(),
                               run_recovery_id=runner_config.run_recovery_id)
    expected_checkpoint_file = "1" + CHECKPOINT_FILE_SUFFIX
    if is_ensemble:
        child_runs = fetch_child_runs(run_to_recover)
        # Without child runs, expected_files would be empty and the checks below would pass vacuously.
        assert child_runs, "The ensemble run to recover is expected to have child runs"
        # Each child's checkpoint is expected under a subfolder named after its cross-validation split index.
        expected_files = [config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME
                          / str(child.get_tags()['cross_validation_split_index']) / expected_checkpoint_file
                          for child in child_runs]
    else:
        expected_files = [config.checkpoint_folder / run_to_recover.id / expected_checkpoint_file]
    checkpoint_paths = run_recovery.get_checkpoint_paths(1)
    if is_ensemble:
        # Length check plus set equality: unlike "every item is in the expected list", this also
        # catches duplicated entries that hide a missing one.
        expected_roots = {expected_file.parent for expected_file in expected_files}
        assert len(run_recovery.checkpoints_roots) == len(expected_files)
        assert set(run_recovery.checkpoints_roots) == expected_roots
        assert len(checkpoint_paths) == len(expected_files)
        assert set(checkpoint_paths) == set(expected_files)
    else:
        assert len(checkpoint_paths) == 1
        assert checkpoint_paths[0] == expected_files[0]
    assert all(expected_file.exists() for expected_file in expected_files)
def test_download_checkpoints_hyperdrive_run(test_output_dirs: OutputFolderForTests,
                                             runner_config: AzureConfig) -> None:
    """
    Recover the checkpoints of each child run of the ensemble run individually. Each child run is passed as
    the third argument of RunRecovery.download_checkpoints_from_recovery_run. Check that each recovery
    yields exactly one epoch-1 checkpoint, stored in a folder named after that child run's ID.

    :param test_output_dirs: Fixture providing a fresh output folder for this test.
    :param runner_config: Azure configuration; its run_recovery_id is overwritten by this test.
    """
    output_dir = test_output_dirs.root_dir
    config = ModelConfigBase(should_validate=False)
    config.set_output_to(output_dir)
    runner_config.run_recovery_id = DEFAULT_ENSEMBLE_RUN_RECOVERY_ID
    child_runs = fetch_child_runs(run=fetch_run(runner_config.get_workspace(), runner_config.run_recovery_id))
    # Without child runs the loop below never runs and the test would pass without checking anything.
    assert child_runs, "The ensemble run to recover is expected to have child runs"
    # recover child runs separately also to test hyperdrive child run recovery functionality
    expected_checkpoint_file = "1" + CHECKPOINT_FILE_SUFFIX
    for child in child_runs:
        expected_file = config.checkpoint_folder / child.id / expected_checkpoint_file
        run_recovery = RunRecovery.download_checkpoints_from_recovery_run(runner_config, config, child)
        checkpoint_paths = run_recovery.get_checkpoint_paths(epoch=1)
        # Require exactly the expected checkpoint: an empty result must not pass.
        assert len(checkpoint_paths) == 1
        assert checkpoint_paths[0] == expected_file
        assert expected_file.exists()