# Imports for the tests below. The InnerEye module paths are an assumption based
# on the InnerEye-DeepLearning repository layout; adjust them if your checkout differs.
import pytest

from InnerEye.ML import metrics
from InnerEye.ML.config import BACKGROUND_CLASS_NAME
from InnerEye.ML.metrics_dict import MetricsDict
from InnerEye.ML.utils.metrics_constants import MetricType


def test_diagnostics() -> None:
    """
    Test if we can store diagnostic values (no restrictions on data types) in the metrics dict.
    """
    name = "foo"
    value1 = "something"
    value2 = (1, 2, 3)
    m = MetricsDict()
    m.add_diagnostics(name, value1)
    m.add_diagnostics(name, value2)
    assert m.diagnostics == {name: [value1, value2]}
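

# A minimal sketch of the diagnostics contract the test above exercises. This is
# not the repository's MetricsDict, only the behaviour implied by the assertions:
# add_diagnostics appends values of any type, in call order, to a per-name list.
class _DiagnosticsSketch:
    """Hypothetical stand-in illustrating the add_diagnostics contract."""

    def __init__(self) -> None:
        self.diagnostics = {}

    def add_diagnostics(self, name, value) -> None:
        # No type restrictions: strings, tuples, arrays are stored verbatim.
        self.diagnostics.setdefault(name, []).append(value)

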
def test_aggregate_segmentation_metrics() -> None:
    """
    Test that per-epoch segmentation metrics are aggregated to compute foreground Dice and voxel count proportions.
    """
    g1 = "Liver"
    g2 = "Lung"
    ground_truth_ids = [BACKGROUND_CLASS_NAME, g1, g2]
    dice = [0.85, 0.75, 0.55]
    voxels_proportion = [0.85, 0.10, 0.05]
    loss = 3.14
    other_metric = 2.71
    m = MetricsDict(hues=ground_truth_ids)
    voxel_count = 200
    # Add 3 values per metric, chosen such that their average equals the value given in dice[j]
    for i in range(3):
        delta = (i - 1) * 0.05
        for j, ground_truth_id in enumerate(ground_truth_ids):
            m.add_metric(MetricType.DICE, dice[j] + delta, hue=ground_truth_id)
            m.add_metric(MetricType.VOXEL_COUNT, int(voxels_proportion[j] * voxel_count), hue=ground_truth_id)
        m.add_metric(MetricType.LOSS, loss + delta)
        m.add_metric("foo", other_metric)
    m.add_diagnostics("foo", "bar")
    aggregate = metrics.aggregate_segmentation_metrics(m)
    assert aggregate.diagnostics == m.diagnostics
    enumerated = list(aggregate.enumerate_single_values())
    expected = [
        # Dice and voxel count per foreground structure should be retained during averaging
        (g1, MetricType.DICE.value, dice[1]),
        (g1, MetricType.VOXEL_COUNT.value, voxels_proportion[1] * voxel_count),
        # Proportion of foreground voxels is computed during averaging
        (g1, MetricType.PROPORTION_FOREGROUND_VOXELS.value, voxels_proportion[1]),
        (g2, MetricType.DICE.value, dice[2]),
        (g2, MetricType.VOXEL_COUNT.value, voxels_proportion[2] * voxel_count),
        (g2, MetricType.PROPORTION_FOREGROUND_VOXELS.value, voxels_proportion[2]),
        # Loss is present in the default metrics group, and should be retained.
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.LOSS.value, loss),
        (MetricsDict.DEFAULT_HUE_KEY, "foo", other_metric),
        # Dice averaged across the foreground structures is added during the function call, as is proportion of voxels
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.DICE.value, 0.5 * (dice[1] + dice[2])),
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.PROPORTION_FOREGROUND_VOXELS.value,
         voxels_proportion[1] + voxels_proportion[2]),
    ]
    assert len(enumerated) == len(expected)
    # Numbers won't match up precisely because of rounding during averaging
    for (actual, e) in zip(enumerated, expected):
        assert actual[0:2] == e[0:2]
        assert actual[2] == pytest.approx(e[2])
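
# Worked numbers for the aggregation checked above, using the same hypothetical
# inputs. aggregate_segmentation_metrics keeps per-structure Dice and voxel
# counts, and derives two foreground summaries in the default hue:
#   mean foreground Dice:   0.5 * (0.75 + 0.55) = 0.65
#   foreground voxel share: 0.10 + 0.05 = 0.15
# Per-structure proportions follow from voxel counts over the total (200) voxels:
#   Liver: (0.10 * 200) / 200 = 0.10,  Lung: (0.05 * 200) / 200 = 0.05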

Example #3

# Third-party imports used by this example; the InnerEye-internal names
# (ModelTrainingStepsBase, SegmentationForwardPass, CroppedSample, MetricsDict,
# the loss classes, and so on) are assumed to be imported from the surrounding
# repository, and their module paths are omitted here.
from typing import Any, Dict, Optional

import numpy as np
import torch

class ModelTrainingStepsForSegmentation(
        ModelTrainingStepsBase[SegmentationModelBase, DeviceAwareModule]):
    """
    This class implements all steps necessary for training an image segmentation model during a single epoch.
    """
    def __init__(self, model_config: SegmentationModelBase,
                 train_val_params: TrainValidateParameters[DeviceAwareModule]):
        """
        Creates a new instance of the class.
        :param model_config: The configuration of a segmentation model.
        :param train_val_params: The parameters for training the model, including the optimizer and the data loaders.
        """
        super().__init__(model_config, train_val_params)
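        # Pick one random batch index; the sample from that batch is stored for
        # visualization (see forward_and_backward_minibatch).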
        self.example_to_save = np.random.randint(
            0, len(train_val_params.data_loader))
        self.pipeline = SegmentationForwardPass(
            model=self.train_val_params.model,
            model_config=self.model_config,
            batch_size=self.model_config.train_batch_size,
            optimizer=self.train_val_params.optimizer,
            in_training_mode=self.train_val_params.in_training_mode,
            criterion=self.compute_loss,
            gradient_scaler=train_val_params.gradient_scaler)
        self.metrics = MetricsDict(hues=[BACKGROUND_CLASS_NAME] +
                                   model_config.ground_truth_ids)

    def create_loss_function(self) -> torch.nn.Module:
        """
        Returns a torch module that computes a loss function.
        """
        return self.construct_loss_function(self.model_config)

    @classmethod
    def construct_loss_function(
            cls, model_config: SegmentationModelBase
    ) -> SupervisedLearningCriterion:
        """
        Returns a loss function from the model config; mixture losses are constructed as weighted combinations of
        other loss functions.
        """
        if model_config.loss_type == SegmentationLoss.Mixture:
            components = model_config.mixture_loss_components
            assert components is not None
            sum_weights = sum(component.weight for component in components)
            weights_and_losses = []
            for component in components:
                normalized_weight = component.weight / sum_weights
                loss_function = cls.construct_non_mixture_loss_function(
                    model_config, component.loss_type,
                    component.class_weight_power)
                weights_and_losses.append((normalized_weight, loss_function))
            return MixtureLoss(weights_and_losses)
        return cls.construct_non_mixture_loss_function(
            model_config, model_config.loss_type,
            model_config.loss_class_weight_power)
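
    # Hedged usage sketch for the Mixture branch above: with two components of
    # weights 0.3 and 0.9 (the component fields follow the attribute accesses
    # above; the concrete component class may differ), the weights normalize to
    #   sum_weights = 0.3 + 0.9 = 1.2
    #   normalized  = [0.3 / 1.2, 0.9 / 1.2] = [0.25, 0.75]
    # before MixtureLoss combines the two loss functions with those weights.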

    @classmethod
    def construct_non_mixture_loss_function(
            cls, model_config: SegmentationModelBase,
            loss_type: SegmentationLoss,
            power: Optional[float]) -> SupervisedLearningCriterion:
        """
        :param model_config: model configuration to get some parameters from
        :param loss_type: type of loss function
        :param power: value for class_weight_power for the loss function
        :return: instance of loss function
        """
        if loss_type == SegmentationLoss.SoftDice:
            return SoftDiceLoss(class_weight_power=power)
        elif loss_type == SegmentationLoss.CrossEntropy:
            return CrossEntropyLoss(
                class_weight_power=power,
                smoothing_eps=model_config.label_smoothing_eps,
                focal_loss_gamma=None)
        elif loss_type == SegmentationLoss.Focal:
            return CrossEntropyLoss(
                class_weight_power=power,
                smoothing_eps=model_config.label_smoothing_eps,
                focal_loss_gamma=model_config.focal_loss_gamma)
        else:
            raise NotImplementedError(
                "Loss type {} is not implemented".format(loss_type))

    def forward_and_backward_minibatch(
            self, sample: Dict[str, Any], batch_index: int,
            epoch: int) -> ModelForwardAndBackwardsOutputs:
        """
        Runs training for a single minibatch of training data, and computes all metrics.
        :param sample: The batched sample on which the model should be trained.
        :param batch_index: The index of the present batch (supplied only for diagnostics).
        :param epoch: The number of the present epoch.
        """
        cropped_sample: CroppedSample = CroppedSample.from_dict(sample=sample)
        labels = self.model_config.get_gpu_tensor_if_possible(
            cropped_sample.labels_center_crop)

        mask = None if self.train_val_params.in_training_mode else cropped_sample.mask_center_crop
        forward_pass_result = self.pipeline.forward_pass_patches(
            patches=cropped_sample.image, labels=labels, mask=mask)
        # Clear the GPU cache between forward and backward passes to avoid possible out-of-memory
        torch.cuda.empty_cache()
        dice_for_all_classes = metrics.compute_dice_across_patches(
            segmentation=torch.tensor(
                forward_pass_result.segmentations).long(),
            ground_truth=labels,
            use_cuda=self.model_config.use_gpu,
            allow_multiple_classes_for_each_pixel=True).cpu().numpy()
        foreground_voxels = metrics_util.get_number_of_voxels_per_class(
            cropped_sample.labels)
        # loss is a scalar, even when the forward pass runs over multiple crops.
        # dice_for_all_classes has one row per crop.
        if forward_pass_result.loss is None:
            raise ValueError(
                "During training, the loss should always be computed, but the value is None."
            )
        loss = forward_pass_result.loss

        # store metrics per batch
        self.metrics.add_metric(MetricType.LOSS, loss)
        for i, ground_truth_id in enumerate(
                self.metrics.get_hue_names(include_default=False)):
            for b in range(dice_for_all_classes.shape[0]):
                self.metrics.add_metric(MetricType.DICE,
                                        dice_for_all_classes[b, i].item(),
                                        hue=ground_truth_id,
                                        skip_nan_when_averaging=True)
            self.metrics.add_metric(MetricType.VOXEL_COUNT,
                                    foreground_voxels[i],
                                    hue=ground_truth_id)
        # store diagnostics per batch
        center_indices = cropped_sample.center_indices
        if isinstance(center_indices, torch.Tensor):
            center_indices = center_indices.cpu().numpy()
        self.metrics.add_diagnostics(MetricType.PATCH_CENTER.value,
                                     np.copy(center_indices))
        if self.train_val_params.in_training_mode:
            # store the sample train patch from this epoch for visualization
            if batch_index == self.example_to_save and self.model_config.store_dataset_sample:
                _store_dataset_sample(self.model_config,
                                      self.train_val_params.epoch,
                                      forward_pass_result, cropped_sample)

        return ModelForwardAndBackwardsOutputs(
            loss=loss,
            logits=forward_pass_result.posteriors,
            labels=forward_pass_result.segmentations)

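    # Shape conventions assumed by the Dice loop above: dice_for_all_classes has
    # one row per crop and one column per class, in the same order as the hue
    # names (background first, then each foreground structure), so Dice is
    # logged once per (crop, class) pair and voxel counts once per class.
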
    def get_epoch_results_and_store(self,
                                    epoch_time_seconds: float) -> MetricsDict:
        """
        Assembles all training results that were achieved over all minibatches, writes them to Tensorboard and
        AzureML, and returns them as a MetricsDict object.
        :param epoch_time_seconds: For diagnostics, this is the total time in seconds for training the present epoch.
        :return: A dictionary that holds all metrics averaged over the epoch.
        """
        self.metrics.add_metric(MetricType.SECONDS_PER_EPOCH,
                                epoch_time_seconds)
        assert len(self.train_val_params.epoch_learning_rate) == 1, \
            "Expected a single entry for learning rate."
        self.metrics.add_metric(MetricType.LEARNING_RATE,
                                self.train_val_params.epoch_learning_rate[0])
        result = metrics.aggregate_segmentation_metrics(self.metrics)
        metrics.store_epoch_metrics(self.azure_and_tensorboard_logger,
                                    self.df_logger,
                                    self.train_val_params.epoch, result,
                                    self.train_val_params.epoch_learning_rate,
                                    self.model_config)
        return result
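

# Hedged sketch of how the class above is typically driven for one epoch. This
# loop is illustrative rather than the repository's actual training loop, and
# `epoch` plus the construction of `train_val_params` are assumed to exist:
#
#   steps = ModelTrainingStepsForSegmentation(model_config, train_val_params)
#   epoch_start = time.time()
#   for batch_index, sample in enumerate(train_val_params.data_loader):
#       steps.forward_and_backward_minibatch(sample, batch_index, epoch)
#   epoch_metrics = steps.get_epoch_results_and_store(time.time() - epoch_start)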