Esempio n. 1
0
    def register_model_for_epoch(self, run_context: Run,
                                 run_recovery: Optional[RunRecovery],
                                 best_epoch: int, best_epoch_dice: float,
                                 model_proc: ModelProcessing) -> None:
        """
        Register the model checkpoint(s) of the given epoch against the given run.
        Checkpoints for best_epoch are located via get_recovery_path_test. If none are found, a
        warning is logged and nothing is registered. For runs that are not offline, the run is tagged
        before registration: when its cross-validation split index tag equals
        DEFAULT_CROSS_VALIDATION_SPLIT_INDEX it gets an IS_ENSEMBLE_KEY_NAME tag (True exactly when
        model_proc is ModelProcessing.ENSEMBLE_CREATION); otherwise, if PARENT_RUN_CONTEXT is set, it
        gets a PARENT_RUN_ID_KEY_NAME tag holding the parent run's id. The actual registration is
        delegated to register_segmentation_model.
        :param run_context: The run to tag and to register the model against.
        :param run_recovery: RunRecovery data if applicable, used to locate the checkpoints.
        :param best_epoch: The epoch whose checkpoints should be registered.
        :param best_epoch_dice: Dice value for best_epoch, passed through to register_segmentation_model.
        :param model_proc: The kind of model processing; ENSEMBLE_CREATION marks ensemble registration.
        """
        checkpoints = get_recovery_path_test(config=self.model_config,
                                             run_recovery=run_recovery,
                                             epoch=best_epoch)
        if not checkpoints:
            # Without at least one checkpoint there is nothing to register.
            logging.warning("Abandoning model registration - no valid checkpoint paths found")
            return

        if not self.model_config.is_offline_run:
            # NOTE(review): confirm the tag value returned by get_tags() has the same type as
            # DEFAULT_CROSS_VALIDATION_SPLIT_INDEX (e.g. str vs int); otherwise this never matches.
            current_split = run_context.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY)
            tags_to_add = None
            if current_split == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
                tags_to_add = {IS_ENSEMBLE_KEY_NAME: model_proc == ModelProcessing.ENSEMBLE_CREATION}
            elif PARENT_RUN_CONTEXT is not None:
                tags_to_add = {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id}
            if tags_to_add is not None:
                update_run_tags(run_context, tags_to_add)

        with logging_section(f"Registering {model_proc.value} model"):
            self.register_segmentation_model(run=run_context,
                                             best_epoch=best_epoch,
                                             best_epoch_dice=best_epoch_dice,
                                             checkpoint_paths=checkpoints,
                                             model_proc=model_proc)
def create_inference_pipeline(
    config: ModelConfigBase,
    epoch: int,
    run_recovery: Optional[RunRecovery] = None
) -> Optional[InferencePipelineBase]:
    """
    Create an inference pipeline from the checkpoints found for the given epoch.
    If multiple checkpoints are found in run_recovery then an ensemble pipeline is created
    (EnsemblePipeline for segmentation models, ScalarEnsemblePipeline for scalar models), otherwise a
    single-model pipeline (InferencePipeline or ScalarInferencePipeline).
    If no checkpoint files exist in the run recovery or current run checkpoint folder, None will be returned.
    :param config: Model related configs.
    :param epoch: The epoch for which to create pipeline for.
    :param run_recovery: RunRecovery data if applicable
    :return: FullImageInferencePipelineBase or ScalarInferencePipelineBase, or None if no checkpoints
        were found.
    :raises NotImplementedError: If the config is neither a segmentation model nor a scalar model.
    """
    checkpoint_paths = get_recovery_path_test(config=config,
                                              run_recovery=run_recovery,
                                              epoch=epoch)
    if not checkpoint_paths:
        return None

    # At least one checkpoint exists here; more than one means the checkpoints form an ensemble.
    is_ensemble = len(checkpoint_paths) > 1
    if config.is_segmentation_model:
        # isinstance assert narrows the config type for static checkers.
        assert isinstance(config, SegmentationModelBase)
        if is_ensemble:
            return EnsemblePipeline.create_from_checkpoints(
                path_to_checkpoints=checkpoint_paths, model_config=config)
        return InferencePipeline.create_from_checkpoint(
            path_to_checkpoint=checkpoint_paths[0], model_config=config)
    if config.is_scalar_model:
        assert isinstance(config, ScalarModelBase)
        if is_ensemble:
            return ScalarEnsemblePipeline.create_from_checkpoint(
                paths_to_checkpoint=checkpoint_paths, config=config)
        return ScalarInferencePipeline.create_from_checkpoint(
            path_to_checkpoint=checkpoint_paths[0], config=config)
    # Name the kind of pipeline that was actually requested, so the error matches the situation.
    pipeline_kind = "ensemble" if is_ensemble else "inference"
    raise NotImplementedError(f"Cannot create {pipeline_kind} pipeline for unknown model type")
Esempio n. 3
0
def create_inference_pipeline(
    config: ModelConfigBase,
    epoch: int,
    run_recovery: Optional[RunRecovery] = None
) -> Optional[InferencePipelineBase]:
    """
    Build an inference pipeline from the checkpoints stored for a given epoch.
    Checkpoints are located with get_recovery_path_test, passing along the config's
    compute_mean_teacher_model flag. Turning those checkpoints into a pipeline (single model or
    ensemble) is delegated to create_pipeline_from_checkpoint_paths.
    :param config: Model related configs.
    :param epoch: The epoch whose checkpoints should be used.
    :param run_recovery: RunRecovery data if applicable.
    :return: The pipeline built from the checkpoints, or None if no checkpoint paths were found.
    """
    use_mean_teacher = config.compute_mean_teacher_model
    found_paths = get_recovery_path_test(config=config,
                                         run_recovery=run_recovery,
                                         is_mean_teacher=use_mean_teacher,
                                         epoch=epoch)
    # An empty result means there is nothing to load, hence no pipeline.
    return create_pipeline_from_checkpoint_paths(config, found_paths) if found_paths else None