def download_checkpoints_from_recovery_run(azure_config: AzureConfig,
                                           config: DeepLearningConfig,
                                           run_context: Optional[Run] = None) -> RunRecovery:
    """
    Downloads the checkpoints of the run corresponding to the run_recovery_id in azure_config, and the
    checkpoints of any child runs if they exist.

    :param azure_config: Azure related configs.
    :param config: Model related configs.
    :param run_context: Context of the current run (will be used to find the target AML workspace).
    :return: RunRecovery object with the downloaded checkpoints.
    """
    run_context = run_context or RUN_CONTEXT
    workspace = azure_config.get_workspace()

    # Find the run to recover in the AML workspace
    if not azure_config.run_recovery_id:
        raise ValueError("A valid run_recovery_id is required to download recovery checkpoints, found None")

    run_to_recover = fetch_run(workspace, azure_config.run_recovery_id.strip())
    # Handle recovery of a HyperDrive cross validation run (from within a successor HyperDrive run,
    # not in ensemble creation). In this case, run_recovery_id refers to the parent prior run, so we
    # need to set run_to_recover to the child of that run whose split index is the same as that of
    # the current (child) run.
    if is_cross_validation_child_run(run_context):
        run_to_recover = next(x for x in fetch_child_runs(run_to_recover) if
                              get_cross_validation_split_index(x) == get_cross_validation_split_index(run_context))

    return RunRecovery.download_checkpoints_from_run(config, run_to_recover)
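For reference, a minimal usage sketch of this helper. The recovery id and the config objects below are illustrative placeholders, not values taken from the example above; in real code both configs come from the project's own configuration loading.

# Hypothetical usage sketch: azure_config is assumed constructible with defaults here
# (normally it is populated from project settings), and model_config stands for a
# DeepLearningConfig instance created elsewhere.
azure_config = AzureConfig()
azure_config.run_recovery_id = "my_experiment:my_experiment_1612345678_abcd1234"
# model_config = <a concrete DeepLearningConfig for the model being recovered>
run_recovery = download_checkpoints_from_recovery_run(azure_config, model_config)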
Example #2
def test_get_cross_validation_split_index_single_run() -> None:
    """
    Test that retrieved cross validation split index is as expected, for single runs.
    """
    run = get_most_recent_run()
    # check for offline run
    assert get_cross_validation_split_index(
        Run.get_context()) == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
    # check for online runs
    assert get_cross_validation_split_index(
        run) == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
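These assertions hinge on runs that are not cross validation children falling back to the default split index. A plausible sketch of such a lookup, purely for illustration (the tag name and the -1 sentinel are assumptions, not the library's actual implementation):

from azureml.core import Run

def get_split_index_sketch(run: Run, default: int = -1) -> int:
    # Cross validation child runs are assumed to carry their split index as a run tag;
    # offline runs and single runs have no such tag and fall back to the default sentinel.
    tags = run.get_tags() or {}
    return int(tags.get("cross_validation_split_index", default))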
Example #3
def test_get_cross_validation_split_index_ensemble_run() -> None:
    """
    Test that retrieved cross validation split index is as expected, for ensembles.
    """
    # check for offline run
    assert get_cross_validation_split_index(
        Run.get_context()) == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
    # check for online runs
    run = get_most_recent_run(
        fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    assert get_cross_validation_split_index(
        run) == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
    assert all([
        get_cross_validation_split_index(x) >
        DEFAULT_CROSS_VALIDATION_SPLIT_INDEX for x in fetch_child_runs(run)
    ])
Example #4
@pytest.mark.parametrize("is_ensemble", [True, False])
def test_get_cross_validation_split_index(is_ensemble: bool) -> None:
    """
    Test that retrieved cross validation split index is as expected, for single runs and ensembles.
    """
    run = fetch_run(workspace=get_default_workspace(),
                    run_recovery_id=DEFAULT_ENSEMBLE_RUN_RECOVERY_ID
                    if is_ensemble else DEFAULT_RUN_RECOVERY_ID)
    # check for offline run
    assert get_cross_validation_split_index(
        Run.get_context()) == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
    # check for online runs
    assert get_cross_validation_split_index(
        run) == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
    if is_ensemble:
        assert all([
            get_cross_validation_split_index(x) >
            DEFAULT_CROSS_VALIDATION_SPLIT_INDEX for x in fetch_child_runs(run)
        ])
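To make the ensemble assertion concrete: the parent HyperDrive run itself reports only the default split index, while each child run carries the index of its own split. A hedged illustration using the same helpers as the tests above (the printed values are indicative only):

parent = fetch_run(workspace=get_default_workspace(),
                   run_recovery_id=DEFAULT_ENSEMBLE_RUN_RECOVERY_ID)
print(get_cross_validation_split_index(parent))  # default sentinel for the parent run
for child in fetch_child_runs(parent):
    # Each cross validation child is expected to report a split index above the default.
    print(get_cross_validation_split_index(child))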