def run_inference_and_register_model(self, checkpoint_handler: CheckpointHandler,
                                     model_proc: ModelProcessing) -> None:
    """
    Register the model (if registration is requested) and then run inference (unless only registration
    is requested). Registration always happens before inference here, because the checkpoints to register
    are taken directly from the checkpoint handler and do not depend on inference results.

    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are running an ensemble model from within a child run with index 0. If we are,
    then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory. The value is passed
    through unchanged to model registration, inference, and the baseline comparison.
    :raises ValueError: If the model should be registered but the checkpoint handler returns no checkpoints.
    """
    if self.should_register_model():
        checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
        if not checkpoint_paths:
            raise ValueError("Model registration failed: No checkpoints found")
        self.register_model(checkpoint_paths, "Registering model.", model_proc)

    if not self.azure_config.only_register_model:
        # Run full image inference on the existing or newly trained model. The returned metrics are not
        # needed here, so they are intentionally not bound to local names.
        self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
                                            model_proc=model_proc)
        self.try_compare_scores_against_baselines(model_proc)
def segmentation_model_test(config: SegmentationModelBase,
                            data_split: ModelExecutionMode,
                            checkpoint_handler: CheckpointHandler,
                            model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> InferenceMetricsForSegmentation:
    """
    Test a segmentation model on one dataset split, once per checkpoint supplied by the checkpoint handler,
    and collect the mean Dice score for each epoch that could be evaluated.

    :param config: The segmentation model configuration. A deep copy of it is handed to each per-epoch test.
    :param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are testing an ensemble or single model
    :return: InferenceMetricsForSegmentation holding the data split and a mapping from epoch to mean Dice.
    :raises ValueError: If the handler returns no checkpoints, or if none of the epochs produced results.
    """
    checkpoints = checkpoint_handler.get_checkpoints_to_test()
    if not checkpoints:
        raise ValueError("There were no checkpoints available for model testing.")

    mean_dice_by_epoch: Dict[int, float] = {}
    for checkpoint in checkpoints:
        epoch = checkpoint.epoch
        output_dir = config.outputs_folder / get_epoch_results_path(epoch, data_split, model_proc)
        # Keep a copy of the dataset files (datasets.csv) next to the results of this epoch.
        config.write_dataset_files(root=output_dir)
        description = "epoch {} {} set".format(epoch, data_split.value)
        dice_per_image = segmentation_model_test_epoch(config=copy.deepcopy(config),
                                                       data_split=data_split,
                                                       checkpoint_paths=checkpoint.checkpoint_paths,
                                                       results_folder=output_dir,
                                                       epoch_and_split=description)
        if dice_per_image is None:
            logging.warning("There is no checkpoint file for epoch {}".format(epoch))
            continue

        # An empty list of per-image scores counts as a mean Dice of 0.
        if len(dice_per_image) > 0:
            mean_dice: float = np.mean(dice_per_image)
        else:
            mean_dice = 0
        mean_dice_by_epoch[epoch] = mean_dice
        logging.info("Epoch: {:3} | Mean Dice: {:4f}".format(epoch, mean_dice))

        if model_proc == ModelProcessing.ENSEMBLE_CREATION:
            # For the upload, we want the path without the "OTHER_RUNS/ENSEMBLE" prefix.
            upload_name = str(get_epoch_results_path(epoch, data_split, ModelProcessing.DEFAULT))
            PARENT_RUN_CONTEXT.upload_folder(name=upload_name, path=str(output_dir))

    if not mean_dice_by_epoch:
        raise ValueError("There was no single checkpoint file available for model testing.")
    return InferenceMetricsForSegmentation(data_split=data_split, epochs=mean_dice_by_epoch)
def classification_model_test(config: ScalarModelBase,
                              data_split: ModelExecutionMode,
                              checkpoint_handler: CheckpointHandler,
                              model_proc: ModelProcessing) -> InferenceMetricsForClassification:
    """
    The main testing routine for classification models. It builds a single inference pipeline from the
    checkpoints supplied by the checkpoint handler, evaluates it sample by sample on the requested dataset
    split, and writes per-subject metrics to a CSV file if that file does not exist yet.

    :param config: The model configuration.
    :param data_split: The dataset split to run inference on. It also determines the results folder and is
    recorded as the mode when per-subject metrics are stored.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are testing an ensemble or single model
    :return: InferenceMetricsForClassification object wrapping the metrics computed on the data split.
    :raises ValueError: If there are no checkpoints to test, or if no inference pipeline could be created.
    """

    def test_epoch(checkpoint_paths: List[Path]) -> Optional[MetricsDict]:
        """
        Run inference with the given checkpoints over the whole data split.

        :param checkpoint_paths: The checkpoints from which the inference pipeline is created.
        :return: The accumulated metrics, or None if no inference pipeline could be created.
        """
        pipeline = create_inference_pipeline(config=config,
                                             checkpoint_paths=checkpoint_paths)
        if pipeline is None:
            return None

        # for mypy
        assert isinstance(pipeline, ScalarInferencePipelineBase)

        ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
        # One sample per batch, in fixed order, so that each prediction maps to exactly one subject.
        ds = config.get_torch_dataset_for_inference(data_split).as_data_loader(
            shuffle=False,
            batch_size=1,
            num_dataload_workers=0
        )

        logging.info(f"Starting to evaluate model on {data_split.value} set.")
        metrics_dict = create_metrics_dict_for_scalar_models(config)
        for sample in ds:
            result = pipeline.predict(sample)
            model_output = result.posteriors
            # Labels must live on the same device as the model output for the metrics computation.
            label = result.labels.to(device=model_output.device)
            sample_id = result.subject_ids[0]
            compute_scalar_metrics(metrics_dict,
                                   subject_ids=[sample_id],
                                   model_output=model_output,
                                   labels=label,
                                   loss_type=config.loss_type)
            logging.debug(f"Example {sample_id}: {metrics_dict.to_string()}")

        average = metrics_dict.average(across_hues=False)
        logging.info(average.to_string())

        return metrics_dict

    checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()
    if not checkpoints_to_test:
        raise ValueError("There were no checkpoints available for model testing.")

    result = test_epoch(checkpoint_paths=checkpoints_to_test)
    if result is None:
        raise ValueError("There was no single checkpoint file available for model testing.")

    if isinstance(result, ScalarMetricsDict):
        # NOTE(review): segmentation_model_test calls get_epoch_results_path with a leading epoch argument;
        # confirm that the helper accepts this two-argument form.
        results_folder = config.outputs_folder / get_epoch_results_path(data_split, model_proc)
        csv_file = results_folder / SUBJECT_METRICS_FILE_NAME

        logging.info(f"Writing {data_split.value} metrics to file {str(csv_file)}")

        # If we are running inference after a training run, the validation set metrics may have been written
        # during train time. If this is not the case, or we are running on the test set, create the metrics
        # file.
        if not csv_file.exists():
            # The folder may already exist without the metrics file in it (for example when other outputs
            # were written there before), hence exist_ok=True: a pre-existing folder must not abort testing.
            os.makedirs(str(results_folder), exist_ok=True)
            df_logger = DataframeLogger(csv_file)

            # cross validation split index not relevant during test time
            result.store_metrics_per_subject(df_logger=df_logger,
                                             mode=data_split)
            # write to disk
            df_logger.flush()

    return InferenceMetricsForClassification(metrics=result)
def classification_model_test(config: ScalarModelBase,
                              data_split: ModelExecutionMode,
                              checkpoint_handler: CheckpointHandler,
                              model_proc: ModelProcessing,
                              cross_val_split_index: int) -> InferenceMetricsForClassification:
    """
    Evaluate a classification model on one dataset split. A single inference pipeline is created from the
    checkpoints supplied by the checkpoint handler and applied to the split one sample at a time. Model
    outputs per target are logged to a CSV file (except for sequence models), and per-subject metrics are
    written to a second CSV file if that file does not exist yet.

    :param config: The model configuration.
    :param data_split: The dataset split to run inference on. It also determines the results folder and is
    recorded as the mode when per-subject metrics are stored.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param model_proc: whether we are testing an ensemble or single model
    :param cross_val_split_index: The cross validation split index written into every model output record.
    It is also stored with the per-subject metrics, unless model_proc is ENSEMBLE_CREATION, in which case
    the default cross validation split index is stored instead.
    :return: InferenceMetricsForClassification object wrapping the metrics computed on the data split.
    :raises ValueError: If there are no checkpoints to test, or if no inference pipeline could be created.
    """
    posthoc_label_transform = config.get_posthoc_label_transform()

    checkpoints = checkpoint_handler.get_checkpoints_to_test()
    if not checkpoints:
        raise ValueError("There were no checkpoints available for model testing.")

    pipeline = create_inference_pipeline(config=config, checkpoint_paths=checkpoints)
    if pipeline is None:
        raise ValueError("Inference pipeline could not be created.")

    # for mypy
    assert isinstance(pipeline, ScalarInferencePipelineBase)

    ml_util.set_random_seed(config.get_effective_random_seed(), "Model Testing")
    inference_dataset = config.get_torch_dataset_for_inference(data_split)
    data_loader = inference_dataset.as_data_loader(shuffle=False, batch_size=1, num_dataload_workers=0)

    logging.info(f"Starting to evaluate model on {data_split.value} set.")
    results_folder = config.outputs_folder / get_best_epoch_results_path(data_split, model_proc)
    os.makedirs(str(results_folder), exist_ok=True)
    metrics_dict = create_metrics_dict_for_scalar_models(config)

    # Raw model outputs are only logged for non-sequence models.
    output_logger: Optional[DataframeLogger] = (
        None if isinstance(config, SequenceModelBase)
        else DataframeLogger(csv_path=results_folder / MODEL_OUTPUT_CSV)
    )

    for sample in data_loader:
        prediction = pipeline.predict(sample)
        model_output = prediction.posteriors
        # Move the labels to the device of the model output before transforming them.
        label = posthoc_label_transform(prediction.labels.to(device=model_output.device))
        sample_id = prediction.subject_ids[0]
        if output_logger:
            # One record per target: label and model output of the single sample in this batch.
            for target_index, target_name in enumerate(config.target_names):
                record = {
                    LoggingColumns.Patient.value: sample_id,
                    LoggingColumns.Hue.value: target_name,
                    LoggingColumns.Label.value: label[0][target_index].item(),
                    LoggingColumns.ModelOutput.value: model_output[0][target_index].item(),
                    LoggingColumns.CrossValidationSplitIndex.value: cross_val_split_index,
                }
                output_logger.add_record(record)

        compute_scalar_metrics(metrics_dict,
                               subject_ids=[sample_id],
                               model_output=model_output,
                               labels=label,
                               loss_type=config.loss_type)
        logging.debug(f"Example {sample_id}: {metrics_dict.to_string()}")

    averaged_metrics = metrics_dict.average(across_hues=False)
    logging.info(averaged_metrics.to_string())

    if isinstance(metrics_dict, ScalarMetricsDict):
        csv_file = results_folder / SUBJECT_METRICS_FILE_NAME
        logging.info(f"Writing {data_split.value} metrics to file {str(csv_file)}")

        # The validation set metrics may already have been written at training time when inference follows
        # a training run. Only when the file is missing (or for the test set) is it created here.
        if not csv_file.exists():
            df_logger = DataframeLogger(csv_file)
            # Ensemble results are stored under the default split index; otherwise record which fold
            # produced this prediction.
            if model_proc == ModelProcessing.ENSEMBLE_CREATION:
                cv_index = DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
            else:
                cv_index = cross_val_split_index
            metrics_dict.store_metrics_per_subject(df_logger=df_logger,
                                                   mode=data_split,
                                                   cross_validation_split_index=cv_index,
                                                   epoch=BEST_EPOCH_FOLDER_NAME)
            # write to disk
            df_logger.flush()

    if output_logger:
        output_logger.flush()

    return InferenceMetricsForClassification(metrics=metrics_dict)