def register_model_for_epoch(self,
                             run_context: Run,
                             run_recovery: Optional[RunRecovery],
                             best_epoch: int,
                             best_epoch_dice: float,
                             model_proc: ModelProcessing) -> None:
    """
    Register the model checkpoint(s) for the given epoch, after first updating the
    AzureML run tags when running online.

    :param run_context: The AzureML run in which registration happens.
    :param run_recovery: Optional recovery information used to locate checkpoints.
    :param best_epoch: The epoch whose checkpoint(s) should be registered.
    :param best_epoch_dice: The Dice score achieved at that epoch.
    :param model_proc: Whether this is a single model or an ensemble registration.
    """
    paths_to_checkpoints = get_recovery_path_test(config=self.model_config,
                                                  run_recovery=run_recovery,
                                                  epoch=best_epoch)
    # Without any checkpoints there is nothing to register - bail out early.
    if not paths_to_checkpoints:
        logging.warning("Abandoning model registration - no valid checkpoint paths found")
        return
    if not self.model_config.is_offline_run:
        # Tag the run so downstream tooling can tell ensemble runs and
        # cross-validation children apart.
        cv_split_index = run_context.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, None)
        if cv_split_index == DEFAULT_CROSS_VALIDATION_SPLIT_INDEX:
            is_ensemble = model_proc == ModelProcessing.ENSEMBLE_CREATION
            update_run_tags(run_context, {IS_ENSEMBLE_KEY_NAME: is_ensemble})
        elif PARENT_RUN_CONTEXT is not None:
            update_run_tags(run_context, {PARENT_RUN_ID_KEY_NAME: PARENT_RUN_CONTEXT.id})
    with logging_section(f"Registering {model_proc.value} model"):
        self.register_segmentation_model(run=run_context,
                                         best_epoch=best_epoch,
                                         best_epoch_dice=best_epoch_dice,
                                         checkpoint_paths=paths_to_checkpoints,
                                         model_proc=model_proc)
def create_inference_pipeline(config: ModelConfigBase,
                              epoch: int,
                              run_recovery: Optional[RunRecovery] = None) -> Optional[InferencePipelineBase]:
    """
    Create an inference pipeline for the given epoch.

    If multiple checkpoints are found in run_recovery then an ensemble pipeline is
    created, otherwise a single-model inference pipeline. If no checkpoint files
    exist in the run recovery or current run checkpoint folder, None is returned.

    :param config: Model related configs.
    :param epoch: The epoch for which to create the pipeline.
    :param run_recovery: RunRecovery data if applicable.
    :return: FullImageInferencePipelineBase or ScalarInferencePipelineBase, or None
        if no checkpoints were found.
    :raises NotImplementedError: If the model type is neither segmentation nor scalar.
    """
    checkpoint_paths = get_recovery_path_test(config=config,
                                              run_recovery=run_recovery,
                                              epoch=epoch)
    if not checkpoint_paths:
        return None
    if len(checkpoint_paths) > 1:
        # Multiple checkpoints: build an ensemble pipeline.
        if config.is_segmentation_model:
            assert isinstance(config, SegmentationModelBase)
            return EnsemblePipeline.create_from_checkpoints(path_to_checkpoints=checkpoint_paths,
                                                            model_config=config)
        elif config.is_scalar_model:
            assert isinstance(config, ScalarModelBase)
            return ScalarEnsemblePipeline.create_from_checkpoint(paths_to_checkpoint=checkpoint_paths,
                                                                 config=config)
        else:
            # Fixed: this branch previously raised the "inference pipeline" message,
            # swapped with the single-checkpoint branch below.
            raise NotImplementedError("Cannot create ensemble pipeline for unknown model type")
    # Exactly one checkpoint: build a single-model pipeline.
    if config.is_segmentation_model:
        assert isinstance(config, SegmentationModelBase)
        return InferencePipeline.create_from_checkpoint(path_to_checkpoint=checkpoint_paths[0],
                                                        model_config=config)
    elif config.is_scalar_model:
        assert isinstance(config, ScalarModelBase)
        return ScalarInferencePipeline.create_from_checkpoint(path_to_checkpoint=checkpoint_paths[0],
                                                              config=config)
    else:
        # Fixed: this branch previously raised the "ensemble pipeline" message.
        raise NotImplementedError("Cannot create inference pipeline for unknown model type")
def create_inference_pipeline(config: ModelConfigBase,
                              epoch: int,
                              run_recovery: Optional[RunRecovery] = None) -> Optional[InferencePipelineBase]:
    """
    Create an inference pipeline for the given epoch.

    If multiple checkpoints are found in run_recovery then an EnsemblePipeline is
    created, otherwise an InferencePipeline. Returns None when no checkpoints exist.

    :param config: Model related configs.
    :param epoch: The epoch for which to create the pipeline.
    :param run_recovery: RunRecovery data if applicable.
    :return: FullImageInferencePipelineBase or ScalarInferencePipelineBase, or None.
    """
    recovered_paths = get_recovery_path_test(config=config,
                                             run_recovery=run_recovery,
                                             is_mean_teacher=config.compute_mean_teacher_model,
                                             epoch=epoch)
    # No checkpoints found -> no pipeline can be built.
    if not recovered_paths:
        return None
    return create_pipeline_from_checkpoint_paths(config, recovered_paths)