Ejemplo n.º 1
0
    def run_inference_and_register_model(self, checkpoint_handler: CheckpointHandler,
                                         model_proc: ModelProcessing) -> None:
        """
        Register the model (if registration is enabled) and then run inference (unless the run is configured to
        only register the model). Registration always happens first: it uses the checkpoints returned by
        checkpoint_handler.get_checkpoints_to_test(), so it does not depend on any inference results.
        :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
        :param model_proc: whether we are processing a single model or creating an ensemble. It is passed through
            unchanged to model registration, inference and the baseline comparison.
        :raises ValueError: If the model should be registered but no checkpoints are available.
        """
        if self.should_register_model():
            checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
            if not checkpoint_paths:
                raise ValueError("Model registration failed: No checkpoints found")
            self.register_model(checkpoint_paths, "Registering model.", model_proc)

        if not self.azure_config.only_register_model:
            # Run full image inference on the existing or newly trained model. The metrics returned by
            # model_inference_train_and_test are not needed here, so they are not unpacked.
            self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
                                                model_proc=model_proc)
            self.try_compare_scores_against_baselines(model_proc)
Ejemplo n.º 2
0
def segmentation_model_test(
    config: SegmentationModelBase,
    data_split: ModelExecutionMode,
    checkpoint_handler: CheckpointHandler,
    model_proc: ModelProcessing = ModelProcessing.DEFAULT
) -> InferenceMetricsForSegmentation:
    """
    The main testing loop for segmentation models.
    It loads the model and datasets, then proceeds to test the model for all requested checkpoints.
    :param config: The arguments object which has a valid random seed attribute.
    :param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are testing an ensemble or single model
    :return: InferenceMetricsForSegmentation object that maps each successfully tested epoch to its mean Dice.
    :raises ValueError: If no checkpoints are available, or none of them produced any results.
    """
    results: Dict[int, float] = {}
    checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()

    if not checkpoints_to_test:
        raise ValueError(
            "There were no checkpoints available for model testing.")

    for checkpoint_paths_and_epoch in checkpoints_to_test:
        epoch = checkpoint_paths_and_epoch.epoch
        epoch_results_folder = config.outputs_folder / get_epoch_results_path(epoch, data_split, model_proc)
        # save the datasets.csv used
        config.write_dataset_files(root=epoch_results_folder)
        epoch_and_split = "epoch {} {} set".format(epoch, data_split.value)
        epoch_dice_per_image = segmentation_model_test_epoch(
            # Work on a copy so that anything the epoch test changes on the config cannot leak into later epochs.
            config=copy.deepcopy(config),
            data_split=data_split,
            checkpoint_paths=checkpoint_paths_and_epoch.checkpoint_paths,
            results_folder=epoch_results_folder,
            epoch_and_split=epoch_and_split)
        if epoch_dice_per_image is None:
            logging.warning("There is no checkpoint file for epoch {}".format(epoch))
            continue

        # np.mean of an empty sequence is NaN (with a RuntimeWarning), so report 0.0 for "no images" instead.
        epoch_average_dice = float(np.mean(epoch_dice_per_image)) if len(epoch_dice_per_image) > 0 else 0.0
        results[epoch] = epoch_average_dice
        # Report the mean Dice with 4 decimal places.
        logging.info("Epoch: {:3} | Mean Dice: {:.4f}".format(epoch, epoch_average_dice))
        if model_proc == ModelProcessing.ENSEMBLE_CREATION:
            # For the upload, we want the path without the "OTHER_RUNS/ENSEMBLE" prefix.
            name = str(get_epoch_results_path(epoch, data_split, ModelProcessing.DEFAULT))
            PARENT_RUN_CONTEXT.upload_folder(name=name, path=str(epoch_results_folder))

    if not results:
        raise ValueError(
            "There was no single checkpoint file available for model testing.")
    return InferenceMetricsForSegmentation(data_split=data_split,
                                           epochs=results)
Ejemplo n.º 3
0
def classification_model_test(config: ScalarModelBase,
                              data_split: ModelExecutionMode,
                              checkpoint_handler: CheckpointHandler,
                              model_proc: ModelProcessing) -> InferenceMetricsForClassification:
    """
    The main testing loop for classification models. It builds an inference pipeline from the checkpoints returned
    by the checkpoint handler, evaluates it on the requested data split, and writes per-subject metrics to disk
    if they have not been written already.
    :param config: The model configuration.
    :param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
                       used mainly in model evaluation using different dataset splits.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are testing an ensemble or single model
    :return: InferenceMetricsForClassification object that contains the metrics computed on the data split.
    :raises ValueError: If no checkpoints are available, or no inference pipeline could be created from them.
    """

    def test_epoch(checkpoint_paths: List[Path]) -> Optional[MetricsDict]:
        """Evaluate the pipeline built from checkpoint_paths on data_split; return None if no pipeline was built."""
        pipeline = create_inference_pipeline(config=config,
                                             checkpoint_paths=checkpoint_paths)

        if pipeline is None:
            return None

        # for mypy
        assert isinstance(pipeline, ScalarInferencePipelineBase)

        ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
        ds = config.get_torch_dataset_for_inference(data_split).as_data_loader(
            shuffle=False,
            batch_size=1,
            num_dataload_workers=0
        )

        logging.info(f"Starting to evaluate model on {data_split.value} set.")
        metrics_dict = create_metrics_dict_for_scalar_models(config)
        for sample in ds:
            prediction = pipeline.predict(sample)
            model_output = prediction.posteriors
            label = prediction.labels.to(device=model_output.device)
            sample_id = prediction.subject_ids[0]
            compute_scalar_metrics(metrics_dict,
                                   subject_ids=[sample_id],
                                   model_output=model_output,
                                   labels=label,
                                   loss_type=config.loss_type)
            logging.debug(f"Example {sample_id}: {metrics_dict.to_string()}")

        average = metrics_dict.average(across_hues=False)
        logging.info(average.to_string())

        return metrics_dict

    checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()

    if not checkpoints_to_test:
        raise ValueError("There were no checkpoints available for model testing.")

    result = test_epoch(checkpoint_paths=checkpoints_to_test)
    if result is None:
        raise ValueError("There was no single checkpoint file available for model testing.")

    if isinstance(result, ScalarMetricsDict):
        results_folder = config.outputs_folder / get_epoch_results_path(data_split, model_proc)
        csv_file = results_folder / SUBJECT_METRICS_FILE_NAME

        logging.info(f"Writing {data_split.value} metrics to file {str(csv_file)}")

        # If we are running inference after a training run, the validation set metrics may have been written
        # during train time. If this is not the case, or we are running on the test set, create the metrics
        # file.
        if not csv_file.exists():
            # The folder can already exist (holding other outputs) even when the metrics file does not;
            # exist_ok=False would then raise FileExistsError.
            os.makedirs(str(results_folder), exist_ok=True)
            df_logger = DataframeLogger(csv_file)

            # cross validation split index not relevant during test time
            result.store_metrics_per_subject(df_logger=df_logger,
                                             mode=data_split)
            # write to disk
            df_logger.flush()

    return InferenceMetricsForClassification(metrics=result)
Ejemplo n.º 4
0
def classification_model_test(config: ScalarModelBase,
                              data_split: ModelExecutionMode,
                              checkpoint_handler: CheckpointHandler,
                              model_proc: ModelProcessing,
                              cross_val_split_index: int) -> InferenceMetricsForClassification:
    """
    The main testing loop for classification models. It builds a single inference pipeline from the checkpoints
    returned by the checkpoint handler and evaluates it on the requested data split. Results are written into the
    best-epoch results folder: per-target model outputs (except for sequence models) and per-subject metrics.
    :param config: The model configuration.
    :param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
                       used mainly in model evaluation using different dataset splits.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are testing an ensemble or single model
    :param cross_val_split_index: The cross-validation split index recorded with every row of the model output
        CSV. It is also recorded in the per-subject metrics file, except when creating an ensemble, where
        DEFAULT_CROSS_VALIDATION_SPLIT_INDEX is written instead.
    :return: InferenceMetricsForClassification object that contains the metrics computed on the data split.
    :raises ValueError: If no checkpoints are available, or the inference pipeline could not be created.
    """
    posthoc_label_transform = config.get_posthoc_label_transform()

    checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
    if not checkpoint_paths:
        raise ValueError("There were no checkpoints available for model testing.")

    pipeline = create_inference_pipeline(config=config,
                                         checkpoint_paths=checkpoint_paths)
    if pipeline is None:
        raise ValueError("Inference pipeline could not be created.")

    # for mypy
    assert isinstance(pipeline, ScalarInferencePipelineBase)

    ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
    ds = config.get_torch_dataset_for_inference(data_split).as_data_loader(
        shuffle=False,
        batch_size=1,
        num_dataload_workers=0
    )

    logging.info(f"Starting to evaluate model on {data_split.value} set.")
    results_folder = config.outputs_folder / get_best_epoch_results_path(data_split, model_proc)
    os.makedirs(str(results_folder), exist_ok=True)
    metrics_dict = create_metrics_dict_for_scalar_models(config)
    # Per-target model outputs are only logged for non-sequence models.
    if not isinstance(config, SequenceModelBase):
        output_logger: Optional[DataframeLogger] = DataframeLogger(csv_path=results_folder / MODEL_OUTPUT_CSV)
    else:
        output_logger = None

    for sample in ds:
        prediction = pipeline.predict(sample)
        model_output = prediction.posteriors
        label = prediction.labels.to(device=model_output.device)
        label = posthoc_label_transform(label)
        sample_id = prediction.subject_ids[0]
        if output_logger is not None:
            # Batch size is 1, so index 0 selects the only sample in the batch.
            for target_index, target_name in enumerate(config.target_names):
                output_logger.add_record({LoggingColumns.Patient.value: sample_id,
                                          LoggingColumns.Hue.value: target_name,
                                          LoggingColumns.Label.value: label[0][target_index].item(),
                                          LoggingColumns.ModelOutput.value: model_output[0][target_index].item(),
                                          LoggingColumns.CrossValidationSplitIndex.value: cross_val_split_index})

        compute_scalar_metrics(metrics_dict,
                               subject_ids=[sample_id],
                               model_output=model_output,
                               labels=label,
                               loss_type=config.loss_type)
        logging.debug(f"Example {sample_id}: {metrics_dict.to_string()}")

    average = metrics_dict.average(across_hues=False)
    logging.info(average.to_string())

    if isinstance(metrics_dict, ScalarMetricsDict):
        csv_file = results_folder / SUBJECT_METRICS_FILE_NAME

        logging.info(f"Writing {data_split.value} metrics to file {str(csv_file)}")

        # If we are running inference after a training run, the validation set metrics may have been written
        # during train time. If this is not the case, or we are running on the test set, create the metrics
        # file.
        if not csv_file.exists():
            df_logger = DataframeLogger(csv_file)
            # For test if ensemble split should be default, else record which fold produced this prediction
            if model_proc == ModelProcessing.ENSEMBLE_CREATION:
                cv_index = DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
            else:
                cv_index = cross_val_split_index
            metrics_dict.store_metrics_per_subject(df_logger=df_logger,
                                                   mode=data_split,
                                                   cross_validation_split_index=cv_index,
                                                   epoch=BEST_EPOCH_FOLDER_NAME)
            # write to disk
            df_logger.flush()

    if output_logger is not None:
        output_logger.flush()

    return InferenceMetricsForClassification(metrics=metrics_dict)