コード例 #1
0
def test_is_cross_validation_child_run_single_run() -> None:
    """
    A plain (single, non cross validation) run must never be reported as a cross validation child run.
    This is checked both for the offline run context and for the most recent run fetched online.
    """
    latest_run = get_most_recent_run()
    offline_context = Run.get_context()
    # Offline context first, then the online run retrieved above.
    assert not is_cross_validation_child_run(offline_context)
    assert not is_cross_validation_child_run(latest_run)
コード例 #2
0
def test_is_cross_validation_child_run_ensemble_run() -> None:
    """
    Test that cross validation child runs are identified correctly.

    Neither the offline run context nor the parent ensemble run may be reported as a cross validation
    child run. Every child of the ensemble run must be reported as one.
    """
    # check for offline run
    assert not is_cross_validation_child_run(Run.get_context())
    # check for online runs
    run = get_most_recent_run(
        fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    assert not is_cross_validation_child_run(run)
    child_runs = list(fetch_child_runs(run))
    # all() over an empty sequence is True; without this check the test would pass vacuously
    # if no child runs were fetched.
    assert child_runs, "Expected the ensemble run to have at least one child run"
    assert all(is_cross_validation_child_run(x) for x in child_runs)
コード例 #3
0
    def download_checkpoints_from_recovery_run(azure_config: AzureConfig,
                                               config: DeepLearningConfig,
                                               run_context: Optional[Run] = None) -> RunRecovery:
        """
        Downloads checkpoints of run corresponding to the run_recovery_id in azure_config, and any
        checkpoints of the child runs if they exist.

        If the current run is itself a cross validation child run, run_recovery_id is taken to refer to
        the parent of a prior run. Checkpoints are then downloaded from that parent's child run whose
        cross validation split index matches the current run's split index.

        :param azure_config: Azure related configs. The AML workspace is obtained from here.
        :param config: Model related configs.
        :param run_context: Context of the current run. Used to decide whether the current run is a cross
            validation child run and, if so, which split index to recover. Defaults to RUN_CONTEXT.
        :return: The result of RunRecovery.download_checkpoints_from_run for the run being recovered.
        :raises ValueError: If azure_config.run_recovery_id is not set, or if the current run is a cross
            validation child run and the recovered run has no child run with the same split index.
        """
        # Validate before contacting the workspace, so that a missing ID fails fast.
        if not azure_config.run_recovery_id:
            raise ValueError("A valid run_recovery_id is required to download recovery checkpoints, found None")

        run_context = run_context or RUN_CONTEXT
        workspace = azure_config.get_workspace()

        # Find the run to recover in AML workspace
        run_to_recover = fetch_run(workspace, azure_config.run_recovery_id.strip())
        # Handle recovery of a HyperDrive cross validation run (from within a successor HyperDrive run,
        # not in ensemble creation). In this case, run_recovery_id refers to the parent prior run, so we
        # need to set run_to_recover to the child of that run whose split index is the same as that of
        # the current (child) run.
        if is_cross_validation_child_run(run_context):
            # The current run's split index does not change while scanning the children; compute it once.
            split_index = get_cross_validation_split_index(run_context)
            matching_child = next((x for x in fetch_child_runs(run_to_recover)
                                   if get_cross_validation_split_index(x) == split_index), None)
            # A bare next() would surface as an unexplained StopIteration here.
            if matching_child is None:
                raise ValueError(f"Run {run_to_recover.id} has no child run with cross validation split index "
                                 f"{split_index}, so checkpoints for this split cannot be recovered.")
            run_to_recover = matching_child

        return RunRecovery.download_checkpoints_from_run(config, run_to_recover)
コード例 #4
0
 def wait_until_cross_val_splits_are_ready_for_aggregation(self) -> bool:
     """
     Checks if all child runs (except the current run) of the current run's parent have finished, i.e. they
     are completed, failed, or cancelled. If this is the case, then we can aggregate the results of the other
     runs before terminating this run.

     :return: whether we need to wait, i.e. whether some sibling runs are still pending.
     :raises NotImplementedError: If this is an offline run, or the current run is not a cross validation
         child run in AzureML.
     """
     if self.model_config.is_offline_run or not azure_util.is_cross_validation_child_run(RUN_CONTEXT):
         raise NotImplementedError(
             "wait_until_cross_val_splits_are_ready_for_aggregation is implemented for online "
             "cross validation runs only")
     n_splits = self.model_config.get_total_number_of_cross_validation_runs()
     child_runs = azure_util.fetch_child_runs(
         PARENT_RUN_CONTEXT,
         expected_number_cross_validation_splits=n_splits)
     # CANCELED is a final status as well. Without it, a cancelled sibling would always be counted as
     # pending, and this method would keep asking the caller to wait.
     finished_statuses = (RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELED)
     pending_runs = [
         x.id for x in child_runs
         if x.id != RUN_CONTEXT.id and x.get_status() not in finished_statuses
     ]
     should_wait = len(pending_runs) > 0
     if should_wait:
         logging.info("Waiting for sibling run(s) to finish: %s", pending_runs)
     return should_wait
コード例 #5
0
 def are_sibling_runs_finished(self) -> bool:
     """
     Checks if all child runs (except the current run) of the current run's parent have finished, i.e. they
     are completed, failed, or cancelled.

     :return: True if all sibling runs of the current run have finished (they completed successfully, failed,
         or were cancelled). False if any of them is still pending (e.g. running or queued).
     :raises NotImplementedError: If this is an offline run, or the current run is not a cross validation
         child run in AzureML.
     """
     if self.model_config.is_offline_run or not azure_util.is_cross_validation_child_run(RUN_CONTEXT):
         raise NotImplementedError(
             "are_sibling_runs_finished only works for cross validation runs in AzureML."
         )
     n_splits = self.model_config.get_total_number_of_cross_validation_runs()
     child_runs = azure_util.fetch_child_runs(
         PARENT_RUN_CONTEXT,
         expected_number_cross_validation_splits=n_splits)
     # CANCELED is a final status as well. Without it, a cancelled sibling would always be counted as
     # pending, and this method would never report that all siblings have finished.
     finished_statuses = (RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.CANCELED)
     pending_runs = [
         x.id for x in child_runs
         if x.id != RUN_CONTEXT.id and x.get_status() not in finished_statuses
     ]
     all_runs_finished = len(pending_runs) == 0
     if not all_runs_finished:
         logging.info("Waiting for sibling run(s) to finish: %s", pending_runs)
     return all_runs_finished
コード例 #6
0
def test_is_cross_validation_child_run(is_ensemble: bool,
                                       is_numeric: bool) -> None:
    """
    Test that cross validation child runs are identified correctly.

    :param is_ensemble: If True, fetch the default ensemble run and also check that each of its children is
        identified as a cross validation child run. Otherwise, fetch the default single run.
    :param is_numeric: If True, use the numeric form of the run recovery ID.
    """
    if is_ensemble:
        rid = DEFAULT_ENSEMBLE_RUN_RECOVERY_ID_NUMERIC if is_numeric else DEFAULT_ENSEMBLE_RUN_RECOVERY_ID
    else:
        rid = DEFAULT_RUN_RECOVERY_ID_NUMERIC if is_numeric else DEFAULT_RUN_RECOVERY_ID
    run = fetch_run(workspace=get_default_workspace(), run_recovery_id=rid)
    # check for offline run
    assert not is_cross_validation_child_run(Run.get_context())
    # check for online runs
    assert not is_cross_validation_child_run(run)
    if is_ensemble:
        child_runs = list(fetch_child_runs(run))
        # all() over an empty sequence is True; without this check the test would pass vacuously
        # if no child runs were fetched.
        assert child_runs, "Expected the ensemble run to have at least one child run"
        assert all(is_cross_validation_child_run(x) for x in child_runs)