Ejemplo n.º 1
0
    def compute_metrics(self, cropped_sample: CroppedSample,
                        segmentation: torch.Tensor, is_training: bool) -> None:
        """
        Computes and stores all metrics coming out of a single training or validation step.

        Per-crop Dice scores and per-crop foreground voxel counts (background class removed) are fed into the
        train or validation metric objects, the crop center indices are appended to the storing logger's
        diagnostics, an example training patch is optionally written via `store_and_upload_example`, and the
        number of subjects in the batch is logged.

        :param cropped_sample: The batched image crops used for training or validation.
        :param segmentation: The segmentation that was produced by the model.
        :param is_training: If true, the method is called from `training_step`, otherwise it is called from
        `validation_step`.
        """
        # dice_per_crop_and_class has one row per crop, with background class removed.
        # Dice NaN means that both ground truth and prediction are empty.
        dice_per_crop_and_class = compute_dice_across_patches(
            segmentation=segmentation,
            ground_truth=cropped_sample.labels_center_crop,  # type: ignore
            allow_multiple_classes_for_each_pixel=True)[:, 1:]
        # Number of foreground voxels per class, indexed by crop below, with background class removed.
        # Note that this is computed from the full `labels`, not from `labels_center_crop` as used for Dice.
        foreground_voxels = metrics_util.get_number_of_voxels_per_class(
            cropped_sample.labels)[:, 1:]  # type: ignore
        # Store Dice and voxel count per sample in the minibatch. We need a custom aggregation logic for Dice
        # because it can be NaN. Also use custom logging for voxel count because Lightning's batch-size weighted
        # average has a bug.
        # The train/val choice does not change between crops, so pick the metric objects once.
        dice_metric = self.train_dice if is_training else self.val_dice
        voxel_metric = self.train_voxels if is_training else self.val_voxels
        for crop_index in range(dice_per_crop_and_class.shape[0]):
            dice_metric.update(dice_per_crop_and_class[crop_index, :])
            voxel_metric.update(foreground_voxels[crop_index, :])

        # Store diagnostics per batch: the crop center indices, converted to numpy if they arrive as a tensor.
        center_indices = cropped_sample.center_indices
        if isinstance(center_indices, torch.Tensor):
            center_indices = center_indices.cpu().numpy()
        if is_training:
            self.storing_logger.train_diagnostics.append(center_indices)
        else:
            self.storing_logger.val_diagnostics.append(center_indices)

        if is_training and self.config.store_dataset_sample:
            # Store the first training patch of this batch for visualization: [0] removes the batch dimension,
            # and for the image the second [0] also removes the channel dimension.
            # detach() before cpu() so the device copy is not recorded by autograd; values are identical.
            dataset_example = DatasetExample(
                image=cropped_sample.image[0][0].detach().cpu().numpy(),
                labels=cropped_sample.labels[0].detach().cpu().numpy(),
                prediction=segmentation[0].detach().cpu().numpy(),
                header=cropped_sample.metadata[0].image_header,  # type: ignore
                patient_id=cropped_sample.metadata[0].patient_id,  # type: ignore
                epoch=self.current_epoch)
            store_and_upload_example(dataset_example, self.config)

        # Subject count is summed across batches when aggregated for the epoch (reduce_fx=torch.sum).
        num_subjects = cropped_sample.image.shape[0]
        self.log_on_epoch(name=MetricType.SUBJECT_COUNT,
                          value=num_subjects,
                          is_training=is_training,
                          reduce_fx=torch.sum)
Ejemplo n.º 2
0
def test_save_dataset_example(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test if the example dataset can be saved as expected: image, labels and prediction are written as
    separate Nifti files for the given patient and epoch, and are read back with the original spacing and
    the expected voxel values.
    """
    image_size = (10, 20, 30)
    label_size = (2,) + image_size
    spacing = (1, 2, 3)
    # Use a dedicated seeded generator so the test is reproducible without modifying the global numpy
    # random state that other tests may depend on.
    rng = np.random.default_rng(0)
    # Image should look similar to what a photonormalized image looks like: Centered around 0
    image = rng.random(image_size) * 2 - 1
    # Labels are expected in one-hot encoding, predictions as class index
    labels = np.zeros(label_size, dtype=int)
    labels[0] = 1
    labels[0, 5:6, 10:11, 15:16] = 0
    labels[1, 5:6, 10:11, 15:16] = 1
    prediction = np.zeros(image_size, dtype=int)
    prediction[4:7, 9:12, 14:17] = 1
    dataset_sample = DatasetExample(epoch=1,
                                    patient_id=2,
                                    header=ImageHeader(origin=(0, 1, 0),
                                                       # 3x3 identity direction matrix, row by row
                                                       direction=(1, 0, 0,
                                                                  0, 1, 0,
                                                                  0, 0, 1),
                                                       spacing=spacing),
                                    image=image,
                                    prediction=prediction,
                                    labels=labels)

    images_folder = test_output_dirs.root_dir
    config = SegmentationModelBase(
        should_validate=False,
        norm_method=PhotometricNormalizationMethod.Unchanged)
    config.set_output_to(images_folder)
    store_and_upload_example(dataset_sample, config)
    # For patient_id=2 and epoch=1 the files are expected as p2_e_1_<image|label|prediction>.nii.gz
    image_from_disk = io_util.load_nifti_image(
        os.path.join(config.example_images_folder, "p2_e_1_image.nii.gz"))
    labels_from_disk = io_util.load_nifti_image(
        os.path.join(config.example_images_folder, "p2_e_1_label.nii.gz"))
    prediction_from_disk = io_util.load_nifti_image(
        os.path.join(config.example_images_folder, "p2_e_1_prediction.nii.gz"))
    assert image_from_disk.header.spacing == spacing
    # When no photometric normalization is provided when saving, image is multiplied by 1000.
    # It is then cast to int16 (numpy's astype truncates towards zero, it does not round), and comes back
    # as float when read back in.
    expected_from_disk = (image * 1000).astype(np.int16).astype(np.float64)
    assert np.array_equal(image_from_disk.image, expected_from_disk)
    # One-hot labels are expected to be stored as a class-index map.
    assert labels_from_disk.header.spacing == spacing
    assert np.array_equal(labels_from_disk.image, np.argmax(labels, axis=0))
    assert prediction_from_disk.header.spacing == spacing
    assert np.array_equal(prediction_from_disk.image, prediction)
Ejemplo n.º 3
0
def _store_dataset_sample(config: SegmentationModelBase, epoch: int,
                          forward_pass_result: SegmentationForwardPass.Result,
                          sample: CroppedSample) -> None:
    """
    Stores the first sample in a batch, along with its result from the model forward pass,
    as Nifti to the file system (via `dataset_util.store_and_upload_example`).
    :param config: Training configurations.
    :param epoch: The epoch to which this sample belongs.
    :param forward_pass_result: The result of a model forward pass.
    :param sample: The original crop sample used for training, as returned by the data loader.
    """
    # Pick the first item of the batch as the example; for the image, the second [0] also selects
    # the first channel.
    # detach().cpu() lets the numpy conversion work for tensors that track gradients or live on a GPU,
    # where a bare .numpy() raises. For plain CPU tensors the result is unchanged.
    example = DatasetExample(
        epoch=epoch,
        # noinspection PyTypeChecker
        patient_id=sample.metadata[0].patient_id,  # type: ignore
        image=sample.image[0][0].detach().cpu().numpy(),
        labels=sample.labels[0].detach().cpu().numpy(),
        prediction=forward_pass_result.segmentations[0],
        header=sample.metadata[0].image_header)  # type: ignore
    dataset_util.store_and_upload_example(dataset_example=example, args=config)