def test_create_from_checkpoint_non_ensemble(test_output_dirs: OutputFolderForTests) -> None:
    """
    ScalarInferencePipeline.create_from_checkpoint must return None while the checkpoint file is absent,
    and a ScalarInferencePipeline once a checkpoint has been stored at that same path.
    """
    model_config = ClassificationModelForTesting()
    checkpoint_file = test_output_dirs.root_dir / "foo.ckpt"

    # No checkpoint has been written to this path yet, so no pipeline can be built from it.
    missing_pipeline = ScalarInferencePipeline.create_from_checkpoint(checkpoint_file, model_config)
    assert missing_pipeline is None

    # Store a checkpoint at the same location, then loading must succeed.
    create_model_and_store_checkpoint(model_config, checkpoint_file)
    loaded_pipeline = ScalarInferencePipeline.create_from_checkpoint(checkpoint_file, model_config)
    assert isinstance(loaded_pipeline, ScalarInferencePipeline)
def test_create_from_checkpoint_non_ensemble() -> None:
    """
    Checks ScalarInferencePipeline.create_from_checkpoint against checkpoints in the ML test data folder:
    a path to a file that does not exist yields None, while 1_checkpoint.pth.tar yields a pipeline that is
    expected to report epoch 1.
    """
    model_config = ClassificationModelForTesting()

    # A path that does not point to an existing checkpoint file must not produce a pipeline.
    missing_checkpoint = full_ml_test_data_path(
        "classification_data_generated_random/checkpoints/non_exist.pth.tar")
    missing_pipeline = ScalarInferencePipeline.create_from_checkpoint(missing_checkpoint, model_config)
    assert missing_pipeline is None

    # The stored test checkpoint should load, and the resulting pipeline should carry its epoch.
    existing_checkpoint = full_ml_test_data_path(
        "classification_data_generated_random/checkpoints/1_checkpoint.pth.tar")
    loaded_pipeline = ScalarInferencePipeline.create_from_checkpoint(existing_checkpoint, model_config)
    assert isinstance(loaded_pipeline, ScalarInferencePipeline)
    assert loaded_pipeline.epoch == 1
def create_pipeline_from_checkpoint_paths(
        config: ModelConfigBase,
        checkpoint_paths: List[Path]) -> Optional[InferencePipelineBase]:
    """
    Attempt to create a pipeline from the provided checkpoint paths. If the files referred to by the paths
    do not exist, or if there are no paths, None will be returned.

    With more than one path an ensemble pipeline is created (EnsemblePipeline for segmentation models,
    ScalarEnsemblePipeline for scalar models). With exactly one path a single-model pipeline is created
    (InferencePipeline for segmentation models, ScalarInferencePipeline for scalar models).

    :param config: Model configuration. Must be a SegmentationModelBase when config.is_segmentation_model
        is set, or a ScalarModelBase when config.is_scalar_model is set.
    :param checkpoint_paths: Paths of the checkpoint files to build the pipeline from.
    :return: The created pipeline, or None if there are no paths or the pipeline factory returned None.
    :raises NotImplementedError: If the config is neither a segmentation nor a scalar model.
    """
    if not checkpoint_paths:
        return None
    is_ensemble = len(checkpoint_paths) > 1
    # Segmentation is checked before scalar, matching the original precedence.
    if config.is_segmentation_model:
        assert isinstance(config, SegmentationModelBase)
        if is_ensemble:
            return EnsemblePipeline.create_from_checkpoints(path_to_checkpoints=checkpoint_paths,
                                                            model_config=config)
        return InferencePipeline.create_from_checkpoint(path_to_checkpoint=checkpoint_paths[0],
                                                        model_config=config)
    if config.is_scalar_model:
        assert isinstance(config, ScalarModelBase)
        if is_ensemble:
            return ScalarEnsemblePipeline.create_from_checkpoint(paths_to_checkpoint=checkpoint_paths,
                                                                 config=config)
        return ScalarInferencePipeline.create_from_checkpoint(path_to_checkpoint=checkpoint_paths[0],
                                                              config=config)
    # The message names the kind of pipeline that was actually requested (these were previously swapped).
    if is_ensemble:
        raise NotImplementedError("Cannot create ensemble pipeline for unknown model type")
    raise NotImplementedError("Cannot create inference pipeline for unknown model type")
def create_inference_pipeline(config: ModelConfigBase,
                              checkpoint_paths: List[Path]) -> Optional[InferencePipelineBase]:
    """
    Create an inference pipeline from the given checkpoint files: an ensemble pipeline if there is more than
    one checkpoint, otherwise a single-model pipeline. If no checkpoint paths are given, None is returned.

    Segmentation models get EnsemblePipeline / InferencePipeline; scalar models get
    ScalarEnsemblePipeline / ScalarInferencePipeline.

    :param config: Model related configs. Must be a SegmentationModelBase when config.is_segmentation_model
        is set, or a ScalarModelBase when config.is_scalar_model is set.
    :param checkpoint_paths: Paths of the checkpoint files to build the pipeline from.
    :return: The created pipeline, or None if there are no paths or the pipeline factory returned None.
    :raises NotImplementedError: If the config is neither a segmentation nor a scalar model.
    """
    if not checkpoint_paths:
        return None
    is_ensemble = len(checkpoint_paths) > 1
    # Segmentation is checked before scalar, matching the original precedence.
    if config.is_segmentation_model:
        assert isinstance(config, SegmentationModelBase)
        if is_ensemble:
            return EnsemblePipeline.create_from_checkpoints(path_to_checkpoints=checkpoint_paths,
                                                            model_config=config)
        return InferencePipeline.create_from_checkpoint(path_to_checkpoint=checkpoint_paths[0],
                                                        model_config=config)
    if config.is_scalar_model:
        assert isinstance(config, ScalarModelBase)
        if is_ensemble:
            return ScalarEnsemblePipeline.create_from_checkpoint(paths_to_checkpoint=checkpoint_paths,
                                                                 config=config)
        return ScalarInferencePipeline.create_from_checkpoint(path_to_checkpoint=checkpoint_paths[0],
                                                              config=config)
    # The message names the kind of pipeline that was actually requested (these were previously swapped).
    if is_ensemble:
        raise NotImplementedError("Cannot create ensemble pipeline for unknown model type")
    raise NotImplementedError("Cannot create inference pipeline for unknown model type")