def compute_metrics(self, cropped_sample: CroppedSample, segmentation: torch.Tensor,
                    is_training: bool) -> None:
    """
    Computes and records every metric produced by one training or validation step.

    :param cropped_sample: The batch of image crops that was fed to the model.
    :param segmentation: The segmentation the model produced for that batch.
    :param is_training: True when invoked from `training_step`, False when invoked from `validation_step`.
    """
    # One row per crop; column 0 (background) is sliced away. A NaN entry means that ground truth and
    # prediction were both empty for that class.
    per_crop_dice = compute_dice_across_patches(
        segmentation=segmentation,
        ground_truth=cropped_sample.labels_center_crop,  # type: ignore
        allow_multiple_classes_for_each_pixel=True)[:, 1:]
    # Foreground voxel count per class, computed on the full crop labels, again without background.
    per_crop_voxels = metrics_util.get_number_of_voxels_per_class(
        cropped_sample.labels)[:, 1:]  # type: ignore

    # Dice needs its own aggregation because it can be NaN; voxel counts are logged manually as well because
    # Lightning's batch-size weighted average has a bug. Each crop's row is pushed individually.
    if is_training:
        dice_metric, voxel_metric = self.train_dice, self.train_voxels
    else:
        dice_metric, voxel_metric = self.val_dice, self.val_voxels
    for crop_index in range(per_crop_dice.shape[0]):
        dice_metric.update(per_crop_dice[crop_index, :])
        voxel_metric.update(per_crop_voxels[crop_index, :])

    # Remember the crop centers of this batch as diagnostics (converted to numpy when they are a tensor).
    centers = cropped_sample.center_indices
    if isinstance(centers, torch.Tensor):
        centers = centers.cpu().numpy()
    diagnostics = (self.storing_logger.train_diagnostics if is_training
                   else self.storing_logger.val_diagnostics)
    diagnostics.append(centers)

    if is_training and self.config.store_dataset_sample:
        # Keep the first item of the batch for visualization; for the image, index [0][0] drops both the
        # batch and the channel dimension.
        example = DatasetExample(
            image=cropped_sample.image[0][0].cpu().detach().numpy(),
            labels=cropped_sample.labels[0].cpu().detach().numpy(),
            prediction=segmentation[0].cpu().detach().numpy(),
            header=cropped_sample.metadata[0].image_header,  # type: ignore
            patient_id=cropped_sample.metadata[0].patient_id,  # type: ignore
            epoch=self.current_epoch)
        store_and_upload_example(example, self.config)

    # Subject count is summed (not averaged) over the epoch.
    self.log_on_epoch(name=MetricType.SUBJECT_COUNT,
                      value=cropped_sample.image.shape[0],
                      is_training=is_training,
                      reduce_fx=torch.sum)
def test_save_dataset_example(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test that a DatasetExample (image, one-hot labels, class-index prediction) is written to Nifti files by
    store_and_upload_example, and that the files read back with the expected voxel values and spacing.
    """
    image_size = (10, 20, 30)
    label_size = (2, ) + image_size
    spacing = (1, 2, 3)
    np.random.seed(0)
    # Image should look similar to what a photonormalized image looks like: centered around 0, in [-1, 1).
    image = np.random.rand(*image_size) * 2 - 1
    # Labels are expected in one-hot encoding, predictions as class index.
    # Every voxel is background (class 0) except a single voxel, which belongs to class 1.
    labels = np.zeros(label_size, dtype=int)
    labels[0] = 1
    labels[0, 5:6, 10:11, 15:16] = 0
    labels[1, 5:6, 10:11, 15:16] = 1
    # The prediction is a 3x3x3 cube of class 1 around that voxel.
    prediction = np.zeros(image_size, dtype=int)
    prediction[4:7, 9:12, 14:17] = 1
    dataset_sample = DatasetExample(epoch=1,
                                    patient_id=2,
                                    header=ImageHeader(origin=(0, 1, 0),
                                                       direction=(1, 0, 0, 0, 1, 0, 0, 0, 1),
                                                       spacing=spacing),
                                    image=image,
                                    prediction=prediction,
                                    labels=labels)

    images_folder = test_output_dirs.root_dir
    config = SegmentationModelBase(should_validate=False, norm_method=PhotometricNormalizationMethod.Unchanged)
    config.set_output_to(images_folder)
    store_and_upload_example(dataset_sample, config)
    # File names follow the pattern p<patient_id>_e_<epoch>_<kind>.nii.gz
    image_from_disk = io_util.load_nifti_image(os.path.join(config.example_images_folder, "p2_e_1_image.nii.gz"))
    labels_from_disk = io_util.load_nifti_image(os.path.join(config.example_images_folder, "p2_e_1_label.nii.gz"))
    prediction_from_disk = io_util.load_nifti_image(
        os.path.join(config.example_images_folder, "p2_e_1_prediction.nii.gz"))

    assert image_from_disk.header.spacing == spacing
    # When no photometric normalization is provided when saving, the image is multiplied by 1000.
    # The expected on-disk values are the int16-cast (i.e. truncated towards zero, not rounded) values,
    # which come back as float when read in. Values stay within [-1000, 1000), so int16 cannot overflow.
    expected_from_disk = (image * 1000).astype(np.int16).astype(np.float64)
    assert np.array_equal(image_from_disk.image, expected_from_disk)
    assert labels_from_disk.header.spacing == spacing
    # One-hot labels are stored as a single class-index volume.
    assert np.array_equal(labels_from_disk.image, np.argmax(labels, axis=0))
    assert prediction_from_disk.header.spacing == spacing
    assert np.array_equal(prediction_from_disk.image, prediction)
def _store_dataset_sample(config: SegmentationModelBase,
                          epoch: int,
                          forward_pass_result: SegmentationForwardPass.Result,
                          sample: CroppedSample) -> None:
    """
    Stores the first sample in a batch, along with its results from the model forward pass, as Nifti
    via `dataset_util.store_and_upload_example`.

    :param config: Training configurations.
    :param epoch: The epoch to which this sample belongs.
    :param forward_pass_result: The result of a model forward pass.
    :param sample: The original crop sample used for training, as returned by the data loader.
    """
    # Pick the first item of the batch as the example. For the image, index [0][0] additionally takes
    # (presumably) the first channel.
    # detach().cpu() before numpy() mirrors how compute_metrics converts crops: a bare .numpy() raises for
    # tensors on the GPU or tensors that require grad, while for plain CPU tensors the result is identical.
    metadata = sample.metadata[0]  # type: ignore
    example = DatasetExample(
        epoch=epoch,
        # noinspection PyTypeChecker
        patient_id=metadata.patient_id,
        image=sample.image[0][0].detach().cpu().numpy(),
        labels=sample.labels[0].detach().cpu().numpy(),
        prediction=forward_pass_result.segmentations[0],
        header=metadata.image_header)
    dataset_util.store_and_upload_example(dataset_example=example, args=config)