Example #1
import logging
from typing import Any, Dict, List, Optional

from azureml.core import Run

def fetch_child_runs(run: Run, status: Optional[str] = None,
                     expected_number_cross_validation_splits: int = 0) -> List[Run]:
    """
    Fetch child runs for the provided runs that have the provided AML status (or fetch all by default)
    and have a run_recovery_id tag value set (this is to ignore superfluous AML infrastructure platform runs).
    :param run: parent run to fetch child run from
    :param status: if provided, returns only child runs with this status
    :param expected_number_cross_validation_splits: when recovering child runs from AML hyperdrive
    sometimes the get_children function fails to retrieve all children. If the number of child runs
    retrieved by AML is lower than the expected number of splits, we try to retrieve them manually.
    """
    if is_ensemble_run(run):
        run_recovery_id = run.get_tags().get(RUN_RECOVERY_FROM_ID_KEY_NAME, None)
        if run_recovery_id:
            run = fetch_run(run.experiment.workspace, run_recovery_id)
        elif PARENT_RUN_CONTEXT:
            run = PARENT_RUN_CONTEXT
    children_runs = list(run.get_children(tags=RUN_RECOVERY_ID_KEY_NAME))
    if 0 < expected_number_cross_validation_splits != len(children_runs):
        logging.warning(
            f"The expected number of child runs was {expected_number_cross_validation_splits}. "
            f"Fetched only {len(children_runs)} runs. Now trying to fetch them manually.")
        run_ids_to_evaluate = [f"{create_run_recovery_id(run)}_{i}"
                               for i in range(expected_number_cross_validation_splits)]
        children_runs = [fetch_run(run.experiment.workspace, run_id) for run_id in run_ids_to_evaluate]
    if status is not None:
        children_runs = [child_run for child_run in children_runs if child_run.get_status() == status]
    return children_runs
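A minimal usage sketch (hedged: the workspace object and run recovery ID are illustrative placeholders; "Completed" is AzureML's status string for finished runs):

parent = fetch_run(workspace, "my_experiment:HD_1234_abcd")  # hypothetical run recovery ID
completed_children = fetch_child_runs(
    parent,
    status="Completed",
    expected_number_cross_validation_splits=5)
for child in completed_children:
    print(child.id, child.get_status())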
Example #2
    def register_model_for_epoch(self, run_context: Run,
                                 checkpoint_handler: CheckpointHandler,
                                 best_epoch: int, best_epoch_dice: float,
                                 model_proc: ModelProcessing) -> None:
        """
        Registers the model for the given epoch, provided that valid checkpoint paths can be
        found; updates cross-validation and ensemble tags on the run before registering.
        """
        checkpoint_path_and_epoch = checkpoint_handler.get_checkpoint_from_epoch(
            epoch=best_epoch)
        if not checkpoint_path_and_epoch or not checkpoint_path_and_epoch.checkpoint_paths:
            # No point continuing, since no checkpoints were found
            logging.warning(
                "Abandoning model registration - no valid checkpoint paths found"
            )
            return

        if not self.model_config.is_offline_run:
            split_index = run_context.get_tags().get(
                CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
            if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
                update_run_tags(
                    run_context, {
                        IS_ENSEMBLE_KEY_NAME:
                        model_proc == ModelProcessing.ENSEMBLE_CREATION
                    })
            elif PARENT_RUN_CONTEXT is not None:
                update_run_tags(
                    run_context,
                    {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id})
        with logging_section(f"Registering {model_proc.value} model"):
            self.register_segmentation_model(
                run=run_context,
                best_epoch=best_epoch,
                best_epoch_dice=best_epoch_dice,
                checkpoint_paths=checkpoint_path_and_epoch.checkpoint_paths,
                model_proc=model_proc)
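A hedged call-site sketch (the runner object, checkpoint handler, and metric values are placeholders; ModelProcessing.DEFAULT is assumed here to be the enum member for single, non-ensemble models):

runner.register_model_for_epoch(
    run_context=RUN_CONTEXT,               # assumed handle to the current AML run
    checkpoint_handler=checkpoint_handler,
    best_epoch=42,
    best_epoch_dice=0.87,
    model_proc=ModelProcessing.DEFAULT)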
Example #3
def get_cross_validation_split_index(run: Run) -> int:
    """
    Gets the cross validation index from the run's tags or returns the default
    :param run: Run context from which to get index
    :return: The cross validation split index
    """
    if is_offline_run_context(run):
        return DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
    else:
        return int(run.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX))
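The tag lookup falls back to the default both when the run is offline and when the tag is absent. A standalone illustration of that dict.get pattern; the constant values below are assumptions, not the library's:

CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY = "cross_validation_split_index"
DEFAULT_CROSS_VALIDATION_SPLIT_INDEX = -1

tags = {CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY: "3"}   # AML tag values are stored as strings
assert int(tags.get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX)) == 3
# Missing tag: dict.get returns the default, mirroring the offline-run branch above.
assert int({}.get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX)) == -1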
Example #4
    def register_model_for_epoch(self, run_context: Run,
                                 run_recovery: Optional[RunRecovery],
                                 best_epoch: int, best_epoch_dice: float,
                                 model_proc: ModelProcessing) -> None:
        """
        Registers the model for the given epoch, taking checkpoint paths from the run recovery
        object if one is available, and discarding any paths that do not exist on disk.
        """
        checkpoint_paths = [self.model_config.get_path_to_checkpoint(best_epoch)] if not run_recovery \
            else run_recovery.get_checkpoint_paths(best_epoch)
        if not self.model_config.is_offline_run:
            split_index = run_context.get_tags().get(
                CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
            if split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
                update_run_tags(
                    run_context, {
                        IS_ENSEMBLE_KEY_NAME:
                        model_proc == ModelProcessing.ENSEMBLE_CREATION
                    })
            elif PARENT_RUN_CONTEXT is not None:
                update_run_tags(
                    run_context,
                    {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id})
        # Discard any checkpoint paths that do not exist - they will make registration fail. This can happen
        # when some child runs fail; it may still be worth registering the model.
        valid_checkpoint_paths = []
        for path in checkpoint_paths:
            if path.exists():
                valid_checkpoint_paths.append(path)
            else:
                logging.warning(
                    f"Discarding non-existent checkpoint path {path}")
        if not valid_checkpoint_paths:
            # No point continuing, since no checkpoints were found
            logging.warning(
                "Abandoning model registration - no valid checkpoint paths found"
            )
            return
        with logging_section(f"Registering {model_proc.value} model"):
            self.register_segmentation_model(
                run=run_context,
                best_epoch=best_epoch,
                best_epoch_dice=best_epoch_dice,
                checkpoint_paths=valid_checkpoint_paths,
                model_proc=model_proc)
def is_ensemble_run(run: Run) -> bool:
    """Checks if the run was an ensemble of multiple models."""
    return run.get_tags().get(IS_ENSEMBLE_KEY_NAME) == 'True'

def update_run_tags(run: Run, tags: Dict[str, Any]) -> None:
    """Updates the tags of the given run with the provided dictionary, merging with the existing tags."""
    run.set_tags({**run.get_tags(), **tags})
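Note that update_run_tags merges rather than replaces: existing tags survive, and new keys win on conflict. A minimal sketch of the same {**old, **new} semantics with plain dictionaries:

existing = {"best_model": "rf", "accuracy": "0.91"}
new_tags = {"accuracy": "0.93", "is_ensemble": "True"}
merged = {**existing, **new_tags}   # the later dict wins on duplicate keys
assert merged == {"best_model": "rf", "accuracy": "0.93", "is_ensemble": "True"}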
Example #7
    # sys.exit("Currently this model registration script can only run in "+
    #     "context of a parent pipeline.")
else:
    ws = run.experiment.workspace
    print("...getting arguments (model_name, training_step_name)")
    model_name = sys.argv[2]
    training_step_name = sys.argv[4]
    parentrun = run.parent

print("model_name:", model_name)
print("training_step_name:", training_step_name)

# The required metrics should be present in the parent run; the condition below also shows an
# alternative approach that reads those metrics directly from the prior training step.
training_run_id = None
tagsdict = parentrun.get_tags()
if tagsdict.get("best_model") is not None:
    model_type = tagsdict['best_model']
    model_accuracy = float(tagsdict['accuracy'])
    training_run_id = parentrun.id
else:
    for step in parentrun.get_children():
        print("Outputs of step " + step.name)
        if step.name == training_step_name:
            tagsdict = step.get_tags()
            model_type = tagsdict['best_model']
            model_accuracy = float(tagsdict['accuracy'])
            training_run_id = step.id

if training_run_id is None:
    sys.exit("Failed to retrieve model information from run.")