def test_diagnostics() -> None:
    """
    Test if we can store diagnostic values (no restrictions on data types) in the metrics dict.
    """
    name = "foo"
    value1 = "something"
    value2 = (1, 2, 3)
    m = MetricsDict()
    m.add_diagnostics(name, value1)
    m.add_diagnostics(name, value2)
    assert m.diagnostics == {name: [value1, value2]}
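
# For illustration only: the test above relies on add_diagnostics appending values of
# arbitrary type under a given name. A minimal sketch of that accumulation behavior,
# assuming diagnostics live in a plain name -> list mapping (this helper is
# hypothetical, not part of MetricsDict):
def _add_diagnostics_sketch(diagnostics: dict, name: str, value: object) -> None:
    # Append the value to the list for this name, creating the list on first use.
    diagnostics.setdefault(name, []).append(value)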
def test_aggregate_segmentation_metrics() -> None:
    """
    Test how per-epoch segmentation metrics are aggregated to compute foreground dice
    and voxel count proportions.
    """
    g1 = "Liver"
    g2 = "Lung"
    ground_truth_ids = [BACKGROUND_CLASS_NAME, g1, g2]
    dice = [0.85, 0.75, 0.55]
    voxels_proportion = [0.85, 0.10, 0.05]
    loss = 3.14
    other_metric = 2.71
    m = MetricsDict(hues=ground_truth_ids)
    voxel_count = 200
    # Add 3 values per metric, but such that the averages are back at the value given in dice[i]
    for i in range(3):
        delta = (i - 1) * 0.05
        for j, ground_truth_id in enumerate(ground_truth_ids):
            m.add_metric(MetricType.DICE, dice[j] + delta, hue=ground_truth_id)
            m.add_metric(MetricType.VOXEL_COUNT, int(voxels_proportion[j] * voxel_count),
                         hue=ground_truth_id)
        m.add_metric(MetricType.LOSS, loss + delta)
        m.add_metric("foo", other_metric)
    m.add_diagnostics("foo", "bar")
    aggregate = metrics.aggregate_segmentation_metrics(m)
    assert aggregate.diagnostics == m.diagnostics
    enumerated = list((g, s, v) for g, s, v in aggregate.enumerate_single_values())
    expected = [
        # Dice and voxel count per foreground structure should be retained during averaging
        (g1, MetricType.DICE.value, dice[1]),
        (g1, MetricType.VOXEL_COUNT.value, voxels_proportion[1] * voxel_count),
        # Proportion of foreground voxels is computed during averaging
        (g1, MetricType.PROPORTION_FOREGROUND_VOXELS.value, voxels_proportion[1]),
        (g2, MetricType.DICE.value, dice[2]),
        (g2, MetricType.VOXEL_COUNT.value, voxels_proportion[2] * voxel_count),
        (g2, MetricType.PROPORTION_FOREGROUND_VOXELS.value, voxels_proportion[2]),
        # Loss is present in the default metrics group, and should be retained.
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.LOSS.value, loss),
        (MetricsDict.DEFAULT_HUE_KEY, "foo", other_metric),
        # Dice averaged across the foreground structures is added during the function call,
        # as is the proportion of foreground voxels.
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.DICE.value, 0.5 * (dice[1] + dice[2])),
        (MetricsDict.DEFAULT_HUE_KEY, MetricType.PROPORTION_FOREGROUND_VOXELS.value,
         voxels_proportion[1] + voxels_proportion[2]),
    ]
    assert len(enumerated) == len(expected)
    # Numbers won't match up precisely because of rounding during averaging
    for (actual, e) in zip(enumerated, expected):
        assert actual[0:2] == e[0:2]
        assert actual[2] == pytest.approx(e[2])
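
# The arithmetic behind the expected values above, spelled out as a standalone worked
# example with the test's numbers (the helper name is illustrative, not part of the
# code under test):
def _aggregation_arithmetic_sketch() -> None:
    dice = [0.85, 0.75, 0.55]              # background, Liver, Lung
    voxels_proportion = [0.85, 0.10, 0.05]
    # The averaged Dice stored in the default hue covers foreground structures only:
    assert 0.5 * (dice[1] + dice[2]) == pytest.approx(0.65)
    # The overall foreground voxel proportion is the sum over foreground structures:
    assert voxels_proportion[1] + voxels_proportion[2] == pytest.approx(0.15)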
class ModelTrainingStepsForSegmentation(ModelTrainingStepsBase[SegmentationModelBase, DeviceAwareModule]):
    """
    This class implements all steps necessary for training an image segmentation model during a single epoch.
    """

    def __init__(self, model_config: SegmentationModelBase,
                 train_val_params: TrainValidateParameters[DeviceAwareModule]):
        """
        Creates a new instance of the class.
        :param model_config: The configuration of a segmentation model.
        :param train_val_params: The parameters for training the model, including the optimizer and the data loaders.
        """
        super().__init__(model_config, train_val_params)
        self.example_to_save = np.random.randint(0, len(train_val_params.data_loader))
        self.pipeline = SegmentationForwardPass(model=self.train_val_params.model,
                                                model_config=self.model_config,
                                                batch_size=self.model_config.train_batch_size,
                                                optimizer=self.train_val_params.optimizer,
                                                in_training_mode=self.train_val_params.in_training_mode,
                                                criterion=self.compute_loss,
                                                gradient_scaler=train_val_params.gradient_scaler)
        self.metrics = MetricsDict(hues=[BACKGROUND_CLASS_NAME] + model_config.ground_truth_ids)

    def create_loss_function(self) -> torch.nn.Module:
        """
        Returns a torch module that computes a loss function.
        """
        return self.construct_loss_function(self.model_config)

    @classmethod
    def construct_loss_function(cls, model_config: SegmentationModelBase) -> SupervisedLearningCriterion:
        """
        Returns a loss function from the model config; mixture losses are constructed as weighted combinations
        of other loss functions.
        """
        if model_config.loss_type == SegmentationLoss.Mixture:
            components = model_config.mixture_loss_components
            assert components is not None
            sum_weights = sum(component.weight for component in components)
            weights_and_losses = []
            for component in components:
                normalized_weight = component.weight / sum_weights
                loss_function = cls.construct_non_mixture_loss_function(model_config,
                                                                        component.loss_type,
                                                                        component.class_weight_power)
                weights_and_losses.append((normalized_weight, loss_function))
            return MixtureLoss(weights_and_losses)
        return cls.construct_non_mixture_loss_function(model_config,
                                                       model_config.loss_type,
                                                       model_config.loss_class_weight_power)

    @classmethod
    def construct_non_mixture_loss_function(cls,
                                            model_config: SegmentationModelBase,
                                            loss_type: SegmentationLoss,
                                            power: Optional[float]) -> SupervisedLearningCriterion:
        """
        :param model_config: model configuration to get some parameters from
        :param loss_type: type of loss function
        :param power: value for class_weight_power for the loss function
        :return: instance of loss function
        """
        if loss_type == SegmentationLoss.SoftDice:
            return SoftDiceLoss(class_weight_power=power)
        elif loss_type == SegmentationLoss.CrossEntropy:
            return CrossEntropyLoss(class_weight_power=power,
                                    smoothing_eps=model_config.label_smoothing_eps,
                                    focal_loss_gamma=None)
        elif loss_type == SegmentationLoss.Focal:
            return CrossEntropyLoss(class_weight_power=power,
                                    smoothing_eps=model_config.label_smoothing_eps,
                                    focal_loss_gamma=model_config.focal_loss_gamma)
        else:
            raise NotImplementedError("Loss type {} is not implemented".format(loss_type))

    def forward_and_backward_minibatch(self, sample: Dict[str, Any],
                                       batch_index: int, epoch: int) -> ModelForwardAndBackwardsOutputs:
        """
        Runs training for a single minibatch of training data, and computes all metrics.
        :param sample: The batched sample on which the model should be trained.
        :param batch_index: The index of the present batch (supplied only for diagnostics).
        :param epoch: The number of the present epoch.
""" cropped_sample: CroppedSample = CroppedSample.from_dict(sample=sample) labels = self.model_config.get_gpu_tensor_if_possible( cropped_sample.labels_center_crop) mask = None if self.train_val_params.in_training_mode else cropped_sample.mask_center_crop forward_pass_result = self.pipeline.forward_pass_patches( patches=cropped_sample.image, labels=labels, mask=mask) # Clear the GPU cache between forward and backward passes to avoid possible out-of-memory torch.cuda.empty_cache() dice_for_all_classes = metrics.compute_dice_across_patches( segmentation=torch.tensor( forward_pass_result.segmentations).long(), ground_truth=labels, use_cuda=self.model_config.use_gpu, allow_multiple_classes_for_each_pixel=True).cpu().numpy() foreground_voxels = metrics_util.get_number_of_voxels_per_class( cropped_sample.labels) # loss is a scalar, also when running the forward pass over multiple crops. # dice_for_all_structures has one row per crop. if forward_pass_result.loss is None: raise ValueError( "During training, the loss should always be computed, but the value is None." ) loss = forward_pass_result.loss # store metrics per batch self.metrics.add_metric(MetricType.LOSS, loss) for i, ground_truth_id in enumerate( self.metrics.get_hue_names(include_default=False)): for b in range(dice_for_all_classes.shape[0]): self.metrics.add_metric(MetricType.DICE, dice_for_all_classes[b, i].item(), hue=ground_truth_id, skip_nan_when_averaging=True) self.metrics.add_metric(MetricType.VOXEL_COUNT, foreground_voxels[i], hue=ground_truth_id) # store diagnostics per batch center_indices = cropped_sample.center_indices if isinstance(center_indices, torch.Tensor): center_indices = center_indices.cpu().numpy() self.metrics.add_diagnostics(MetricType.PATCH_CENTER.value, np.copy(center_indices)) if self.train_val_params.in_training_mode: # store the sample train patch from this epoch for visualization if batch_index == self.example_to_save and self.model_config.store_dataset_sample: _store_dataset_sample(self.model_config, self.train_val_params.epoch, forward_pass_result, cropped_sample) return ModelForwardAndBackwardsOutputs( loss=loss, logits=forward_pass_result.posteriors, labels=forward_pass_result.segmentations) def get_epoch_results_and_store(self, epoch_time_seconds: float) -> MetricsDict: """ Assembles all training results that were achieved over all minibatches, writes them to Tensorboard and AzureML, and returns them as a MetricsDict object. :param epoch_time_seconds: For diagnostics, this is the total time in seconds for training the present epoch. :return: A dictionary that holds all metrics averaged over the epoch. """ self.metrics.add_metric(MetricType.SECONDS_PER_EPOCH, epoch_time_seconds) assert len(self.train_val_params.epoch_learning_rate ) == 1, "Expected a single entry for learning rate." self.metrics.add_metric(MetricType.LEARNING_RATE, self.train_val_params.epoch_learning_rate[0]) result = metrics.aggregate_segmentation_metrics(self.metrics) metrics.store_epoch_metrics(self.azure_and_tensorboard_logger, self.df_logger, self.train_val_params.epoch, result, self.train_val_params.epoch_learning_rate, self.model_config) return result