Example #1
0
 def register_model_for_best_epoch(
         self, checkpoint_handler: CheckpointHandler,
         test_metrics: Optional[InferenceMetricsForSegmentation],
         val_metrics: Optional[InferenceMetricsForSegmentation],
         model_proc: ModelProcessing) -> int:
     """
     Choose the epoch to register from whichever metrics are available, register the model for that
     epoch's checkpoints, and return the chosen epoch.
     Validation metrics take precedence over test metrics. If neither is given, the last entry of
     self.model_config.get_test_epochs() is used.
     :param checkpoint_handler: Used to look up the checkpoint paths for the chosen epoch.
     :param test_metrics: Test set inference metrics, or None if not available.
     :param val_metrics: Validation set inference metrics, or None if not available.
     :param model_proc: Passed through unchanged to register_model_for_epoch.
     :return: The epoch for which the model was registered.
     """
     if val_metrics is not None:
         chosen_epoch = val_metrics.get_best_epoch()
         epoch_count = len(val_metrics.epochs)
         # Build the validation part first, then append whatever can be said about the test set.
         description_parts = [
             f"Epoch {chosen_epoch} has best validation set metrics (out of {epoch_count} epochs "
             f"available). Validation set Dice: {val_metrics.epochs[chosen_epoch]}. "
         ]
         description_parts.append(
             f"Test set Dice: {test_metrics.epochs[chosen_epoch]}." if test_metrics
             else "Test set metrics not available.")
         description = "".join(description_parts)
     elif test_metrics is not None:
         # Not expected in normal operation: inference is presently always run on the validation
         # and test sets together, so test metrics without validation metrics should not occur.
         chosen_epoch = test_metrics.get_best_epoch()
         epoch_count = len(test_metrics.epochs)
         description = (
             f"Epoch {chosen_epoch} has best test set metrics (out of {epoch_count} epochs "
             f"available). Test set Dice: {test_metrics.epochs[chosen_epoch]}"
         )
     else:
         # No metrics at all: fall back to the last of the configured test epochs.
         chosen_epoch = self.model_config.get_test_epochs()[-1]
         description = f"Model for epoch {chosen_epoch}. No validation or test set metrics were available."
     paths_for_epoch = checkpoint_handler.get_checkpoint_paths_from_epoch_or_fail(chosen_epoch)
     self.register_model_for_epoch(paths_for_epoch, description, model_proc)
     return chosen_epoch
Example #2
0
    def run_inference_and_register_model(
            self, checkpoint_handler: CheckpointHandler,
            model_proc: ModelProcessing) -> Optional[int]:
        """
        Run inference as required, and register the model, but not necessarily in that order:
        if we can identify the epoch to register at without running inference, we register first.
        :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
        :param model_proc: whether we are running an ensemble model from within a child run with index 0. If we are,
        then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
        :return: The epoch that was chosen for model registration, or None if no epoch was chosen.
        If self.azure_config.register_model_only_for_epoch is set and an epoch could be decided up front,
        that value is returned immediately after registration, without running inference.
        """
        registration_epoch = self.decide_registration_epoch_without_evaluating()
        if registration_epoch is not None:
            model_description = f"Registering model for epoch {registration_epoch} without considering metrics."
            checkpoint_paths = checkpoint_handler.get_checkpoint_paths_from_epoch_or_fail(
                registration_epoch)
            self.register_model_for_epoch(checkpoint_paths, model_description,
                                          model_proc)
            # When registration is requested for one specific epoch only, skip inference entirely.
            only_for_epoch = self.azure_config.register_model_only_for_epoch
            if only_for_epoch is not None:
                return only_for_epoch

        # Run inference on the existing or newly trained model. The third return value of
        # model_inference_train_and_test is not needed here.
        test_metrics, val_metrics, _ = self.model_inference_train_and_test(
            checkpoint_handler=checkpoint_handler, model_proc=model_proc)

        # register the generated model from the run if we haven't already done so
        if self.model_config.is_segmentation_model and (
                not self.model_config.is_offline_run):
            if registration_epoch is None and self.should_register_model():
                # The asserts narrow the metric types for the call below; they are not input validation.
                assert test_metrics is None or isinstance(
                    test_metrics, InferenceMetricsForSegmentation)
                assert val_metrics is None or isinstance(
                    val_metrics, InferenceMetricsForSegmentation)
                registration_epoch = self.register_model_for_best_epoch(
                    checkpoint_handler, test_metrics, val_metrics,
                    model_proc)
            self.try_compare_scores_against_baselines(model_proc)
        elif self.model_config.is_offline_run:
            logging.warning("Couldn't register model in offline mode")
        else:
            # Previously this case also logged the "offline mode" message, which was misleading:
            # the run is online, the model is simply not a segmentation model.
            # NOTE(review): like the offline warning, this is logged even if the model was already
            # registered up front (registration_epoch is not None) - confirm whether that is intended.
            logging.warning("Couldn't register model because it is not a segmentation model")

        return registration_epoch