コード例 #1
0
def test_create_from_checkpoint_non_ensemble(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Check ScalarInferencePipeline.create_from_checkpoint for a single (non-ensemble) checkpoint.

    Loading from a checkpoint location that has not been written yet must give None; after a
    model checkpoint has been stored at that same location, a ScalarInferencePipeline is expected.
    """
    model_config = ClassificationModelForTesting()
    checkpoint_file = test_output_dirs.root_dir / "foo.ckpt"

    # Nothing has been stored at checkpoint_file yet (the test relies on the fresh output folder),
    # so no pipeline can be created from it.
    pipeline_before_save = ScalarInferencePipeline.create_from_checkpoint(
        checkpoint_file, model_config)
    assert pipeline_before_save is None

    # Write a checkpoint to the same location; now a pipeline should be created.
    create_model_and_store_checkpoint(model_config, checkpoint_file)
    pipeline_after_save = ScalarInferencePipeline.create_from_checkpoint(
        checkpoint_file, model_config)
    assert isinstance(pipeline_after_save, ScalarInferencePipeline)
コード例 #2
0
def test_create_from_checkpoint_non_ensemble() -> None:
    """
    Check ScalarInferencePipeline.create_from_checkpoint against the generated classification test data.

    A checkpoint file that is not present must give None. The stored `1_checkpoint.pth.tar` file must give
    a ScalarInferencePipeline whose `epoch` attribute is 1.
    """
    model_config = ClassificationModelForTesting()

    # A checkpoint file that is not part of the test data: no pipeline expected.
    missing_checkpoint_file = full_ml_test_data_path(
        "classification_data_generated_random/checkpoints/non_exist.pth.tar")
    missing_result = ScalarInferencePipeline.create_from_checkpoint(missing_checkpoint_file, model_config)
    assert missing_result is None

    # The checkpoint that ships with the test data: a pipeline for epoch 1 is expected.
    stored_checkpoint_file = full_ml_test_data_path(
        "classification_data_generated_random/checkpoints/1_checkpoint.pth.tar")
    pipeline = ScalarInferencePipeline.create_from_checkpoint(stored_checkpoint_file, model_config)
    assert isinstance(pipeline, ScalarInferencePipeline)
    assert pipeline.epoch == 1
コード例 #3
0
def create_pipeline_from_checkpoint_paths(
        config: ModelConfigBase,
        checkpoint_paths: List[Path]) -> Optional[InferencePipelineBase]:
    """
    Attempt to create an inference pipeline from the provided checkpoint paths.

    With more than one path, an ensemble pipeline is created (EnsemblePipeline for segmentation models,
    ScalarEnsemblePipeline for scalar models). With exactly one path, a single-model pipeline is created
    (InferencePipeline or ScalarInferencePipeline). If there are no paths, None is returned. If the files
    referred to by the paths do not exist, the underlying create_from_checkpoint(s) method is relied on to
    return None.

    :param config: Model configuration; must describe a segmentation or a scalar model.
    :param checkpoint_paths: Paths to the checkpoint files to load.
    :return: The created pipeline, or None as described above.
    :raises NotImplementedError: If paths are given but the model is neither a segmentation nor a scalar model.
    """
    if not checkpoint_paths:
        return None

    if len(checkpoint_paths) > 1:
        # Several checkpoints: combine them into an ensemble pipeline.
        if config.is_segmentation_model:
            # Narrow config to the concrete config class expected by the pipeline factory.
            assert isinstance(config, SegmentationModelBase)
            return EnsemblePipeline.create_from_checkpoints(
                path_to_checkpoints=checkpoint_paths, model_config=config)
        if config.is_scalar_model:
            assert isinstance(config, ScalarModelBase)
            return ScalarEnsemblePipeline.create_from_checkpoint(
                paths_to_checkpoint=checkpoint_paths, config=config)
        raise NotImplementedError(
            "Cannot create ensemble pipeline for unknown model type")

    # Exactly one checkpoint: a single-model pipeline.
    single_checkpoint = checkpoint_paths[0]
    if config.is_segmentation_model:
        assert isinstance(config, SegmentationModelBase)
        return InferencePipeline.create_from_checkpoint(
            path_to_checkpoint=single_checkpoint, model_config=config)
    if config.is_scalar_model:
        assert isinstance(config, ScalarModelBase)
        return ScalarInferencePipeline.create_from_checkpoint(
            path_to_checkpoint=single_checkpoint, config=config)
    raise NotImplementedError(
        "Cannot create inference pipeline for unknown model type")
コード例 #4
0
def create_inference_pipeline(config: ModelConfigBase,
                              checkpoint_paths: List[Path]) -> Optional[InferencePipelineBase]:
    """
    Create an inference pipeline for the given model from the given checkpoint files.

    If more than one checkpoint path is given, an ensemble pipeline is created (EnsemblePipeline for
    segmentation models, ScalarEnsemblePipeline for scalar models). If exactly one path is given, a
    single-model pipeline is created (InferencePipeline or ScalarInferencePipeline). If no paths are
    given, None is returned.

    :param config: Model related configs; must describe a segmentation or a scalar model.
    :param checkpoint_paths: Paths to the checkpoint files to create the pipeline from.
    :return: The pipeline returned by the matching create_from_checkpoint(s) factory method, or None
        if checkpoint_paths is empty.
    :raises NotImplementedError: If the model is neither a segmentation nor a scalar model.
    """
    if not checkpoint_paths:
        return None

    if len(checkpoint_paths) > 1:
        # Several checkpoints: combine them into an ensemble pipeline.
        if config.is_segmentation_model:
            # Narrow config to the concrete config class expected by the pipeline factory.
            assert isinstance(config, SegmentationModelBase)
            return EnsemblePipeline.create_from_checkpoints(path_to_checkpoints=checkpoint_paths, model_config=config)
        if config.is_scalar_model:
            assert isinstance(config, ScalarModelBase)
            return ScalarEnsemblePipeline.create_from_checkpoint(paths_to_checkpoint=checkpoint_paths, config=config)
        raise NotImplementedError("Cannot create ensemble pipeline for unknown model type")

    # Exactly one checkpoint remains possible here: build a single-model pipeline.
    if config.is_segmentation_model:
        assert isinstance(config, SegmentationModelBase)
        return InferencePipeline.create_from_checkpoint(path_to_checkpoint=checkpoint_paths[0],
                                                        model_config=config)
    if config.is_scalar_model:
        assert isinstance(config, ScalarModelBase)
        return ScalarInferencePipeline.create_from_checkpoint(path_to_checkpoint=checkpoint_paths[0],
                                                              config=config)
    raise NotImplementedError("Cannot create inference pipeline for unknown model type")