def test_sample(random_image_crop: Any, random_mask_crop: Any, random_label_crop: Any, random_patient_id: Any) -> None:
    """
    Tests that a Sample and a CroppedSample built from the same arrays return those arrays and the same
    metadata again when extracted via get_dict.
    """
    # NOTE(review): the random_patient_id fixture is requested but unused; the patient ID is hard-coded.
    metadata = PatientMetadata(patient_id=42, institution="foo")
    sample = Sample(image=random_image_crop,
                    mask=random_mask_crop,
                    labels=random_label_crop,
                    metadata=metadata)

    patched_sample = CroppedSample(image=random_image_crop,
                                   mask=random_mask_crop,
                                   labels=random_label_crop,
                                   mask_center_crop=random_mask_crop,
                                   labels_center_crop=random_label_crop,
                                   metadata=metadata,
                                   center_indices=np.zeros((1, 3)))

    extracted_sample = sample.get_dict()
    extracted_patched_sample = patched_sample.get_dict()

    def sample_and_patched_sample_equal(key: str, expected: Any) -> bool:
        """True if both extracted dictionaries hold an array equal to `expected` under `key`."""
        return bool(np.array_equal(extracted_sample[key], extracted_patched_sample[key])
                    and np.array_equal(extracted_patched_sample[key], expected))

    assert sample_and_patched_sample_equal("image", random_image_crop)
    assert sample_and_patched_sample_equal("mask", random_mask_crop)
    assert sample_and_patched_sample_equal("labels", random_label_crop)

    # The center crops exist only on the CroppedSample.
    assert np.array_equal(extracted_patched_sample["mask_center_crop"], random_mask_crop)
    assert np.array_equal(extracted_patched_sample["labels_center_crop"], random_label_crop)
    assert extracted_sample["metadata"] == extracted_patched_sample["metadata"] == metadata
Example #2
0
def test_cropping_dataset_as_data_loader(cropping_dataset: CroppingDataset,
                                         num_dataload_workers: int) -> None:
    """
    Checks the batches produced by the cropping dataset's data loader. Every tensor must have the expected
    batched shape, and the mask/label center crops must equal the center crops of the full mask/label crops.
    """
    batch_size = 2
    loader = cropping_dataset.as_data_loader(
        shuffle=True,
        batch_size=batch_size,
        num_dataload_workers=num_dataload_workers)
    config = cropping_dataset.args
    for batch in loader:
        cropped = CroppedSample.from_dict(sample=batch)
        assert cropped is not None

        # Full crops: (batch, [channels or classes,] *crop_size)
        image_shape = (batch_size, config.number_of_image_channels) + config.crop_size  # type: ignore
        assert cropped.image.shape == image_shape
        mask_shape = (batch_size,) + config.crop_size  # type: ignore
        assert cropped.mask.shape == mask_shape
        labels_shape = (batch_size, config.number_of_classes) + config.crop_size  # type: ignore
        assert cropped.labels.shape == labels_shape

        # Center crops: (batch, [classes,] *center_size)
        mask_center_shape = (batch_size,) + config.center_size  # type: ignore
        assert cropped.mask_center_crop.shape == mask_center_shape
        labels_center_shape = (batch_size, config.number_of_classes) + config.center_size  # type: ignore
        assert cropped.labels_center_crop.shape == labels_center_shape

        # The stored center crops must match a center crop recomputed from the full crops.
        for b in range(batch_size):
            recomputed_mask = image_util.get_center_crop(image=cropped.mask[b],
                                                         crop_shape=config.center_size)
            assert np.array_equal(cropped.mask_center_crop[b], recomputed_mask)

            for class_index, stored_label_center in enumerate(cropped.labels_center_crop[b]):
                recomputed_label = image_util.get_center_crop(image=cropped.labels[b][class_index],
                                                              crop_shape=config.center_size)
                assert np.array_equal(stored_label_center, recomputed_label)
Example #3
0
 def check_patient_id_in_dataset(loader: DataLoader, split: pd.DataFrame) -> None:
     """
     Asserts that every batch from `loader` holds exactly one sample's metadata, and that the patient ID of
     that sample appears in the `subject` column of `split`.
     """
     known_subjects = list(split.subject.unique())
     for batch in loader:
         metadata = CroppedSample.from_dict(batch).metadata
         assert isinstance(metadata, list)
         assert len(metadata) == 1
         assert metadata[0].patient_id in known_subjects
def test_cropping_dataset_padding(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Tests zero padding in the cropping dataset. With the crop size set to (300, 300, 300) and the padding
    mode set to PaddingMode.Zero, every batch from the data loader must contain images whose last three
    dimensions equal that crop size.
    """
    # NOTE(review): presumably (300, 300, 300) exceeds the test image size so that padding is exercised -
    # confirm against the fixture. These assignments mutate the fixture's args in place.
    cropping_dataset.args.crop_size = (300, 300, 300)
    cropping_dataset.args.padding_mode = PaddingMode.Zero
    # NOTE(review): a single worker is hard-coded; the num_dataload_workers fixture is not used here.
    loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=2, num_dataload_workers=1)

    for item in loader:
        sample = CroppedSample.from_dict(item)
        assert sample.image.shape[-3:] == cropping_dataset.args.crop_size
Example #5
0
def extract_activation_maps(args: ModelConfigBase) -> None:
    """
    Extracts the activation maps of a specific layer of a trained network for the first image of the first
    validation batch. The maps are passed to visualize_3d_activation_map or visualize_2d_activation_map,
    depending on their rank.

    :param args: Model configuration. It provides the checkpoint path, the GPU setting, the data loaders and
        the layer(s) to hook (activation_map_layers).
    :raises FileNotFoundError: If no checkpoint file exists at the configured path.
    :raises NotImplementedError: If the activation map is neither 3- nor 4-dimensional.
    """
    # Check for the checkpoint before building the model, so a missing file fails fast.
    checkpoint_path = args.get_path_to_checkpoint()
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")

    model = create_model_with_temperature_scaling(args)
    if args.use_gpu:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()

    # Without map_location, a checkpoint holding GPU tensors cannot be loaded on a CPU-only machine.
    map_location = None if args.use_gpu else torch.device("cpu")
    checkpoint = torch.load(checkpoint_path, map_location=map_location)  # type: ignore
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    val_loader = args.create_data_loaders()[ModelExecutionMode.VAL]
    feature_extractor = model_hooks.HookBasedFeatureExtractor(
        model, layer_name=args.activation_map_layers)

    # Only the first validation batch is visualized; an empty loader leaves nothing to do.
    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        return

    sample = CroppedSample.from_dict(sample=first_batch)
    input_image = sample.image.float()
    if args.use_gpu:
        # The model was moved to the GPU only when use_gpu is set; the input must live on the same device.
        input_image = input_image.cuda()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        feature_extractor(input_image)

    # Access the first image of the batch of feature maps.
    activation_map = feature_extractor.outputs[0][0].cpu().numpy()

    if activation_map.ndim == 4:
        visualize_3d_activation_map(activation_map, args)
    elif activation_map.ndim == 3:
        visualize_2d_activation_map(activation_map, args)
    else:
        raise NotImplementedError(
            f"cannot visualize activation map of shape {activation_map.shape}")
def test_cropping_dataset_has_reproducible_randomness(cropping_dataset: CroppingDataset,
                                                      num_dataload_workers: int) -> None:
    """
    Tests that re-seeding with ml_util.set_random_seed before creating a data loader makes the randomly
    chosen crop centers reproducible. Three loaders created after the same seed must yield identical center
    indices, batch by batch.
    """
    # Four items (indices 1, 2, 1, 2); with batch_size=4 this presumably gives one batch per loader.
    cropping_dataset.dataset_indices = [1, 2] * 2
    expected_center_indices = None
    for _ in range(3):
        ml_util.set_random_seed(1)
        loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=4,
                                                 num_dataload_workers=num_dataload_workers)
        # Batch i of this loader is compared against batch i of the first loader (rather than every batch
        # against the very first batch), so the test stays valid if a loader yields more than one batch.
        center_indices = [CroppedSample.from_dict(sample=item).center_indices for item in loader]
        if expected_center_indices is None:
            expected_center_indices = center_indices
        else:
            assert len(center_indices) == len(expected_center_indices)
            for expected, actual in zip(expected_center_indices, center_indices):
                assert np.array_equal(expected, actual)
def test_cropping_dataset_sample_dtype(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Checks the data types of the tensors produced by the data loader, which are fed into the computational
    graph. Image, labels, mask and both center crops must each have the dtype that ImageDataType
    prescribes for them.
    """
    loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=2,
                                             num_dataload_workers=num_dataload_workers)
    for batch in loader:
        cropped = CroppedSample.from_dict(batch)
        expectations = (
            (cropped.image, ImageDataType.IMAGE),
            (cropped.labels, ImageDataType.SEGMENTATION),
            (cropped.mask, ImageDataType.MASK),
            (cropped.mask_center_crop, ImageDataType.MASK),
            (cropped.labels_center_crop, ImageDataType.SEGMENTATION),
        )
        for tensor, data_type in expectations:
            assert tensor.numpy().dtype == data_type.value
Example #8
0
    def create_random_cropped_sample(
            sample: Sample,
            crop_size: TupleInt3,
            center_size: TupleInt3,
            class_weights: Optional[List[float]] = None) -> CroppedSample:
        """
        Takes a random crop from a full-size 3D sample and returns it as a CroppedSample, together with the
        center crops of its mask and labels.

        :param sample: The full-size 3D sample from which the crop is taken.
        :param crop_size: The size of the crop to extract.
        :param center_size: The size of the center region of the crop (this should equal the spatial
            dimensions of the posteriors that the model produces).
        :param class_weights: The distribution to use for choosing the class of the crop center.
        :return: A CroppedSample holding the crop, the center crops of its mask and labels, the crop center
            point returned by random_crop, and the metadata of the cropped sample.
        """
        cropped, center_point = random_crop(sample=sample,
                                            crop_size=crop_size,
                                            class_weights=class_weights)

        if center_size != crop_size:
            # Reduce the mask and each label channel to the central region.
            mask_center = image_util.get_center_crop(image=cropped.mask, crop_shape=center_size)
            num_channels = len(cropped.labels)  # type: ignore
            labels_center = np.zeros(shape=[num_channels, *center_size],
                                     dtype=ImageDataType.SEGMENTATION.value)
            for channel in range(num_channels):
                labels_center[channel] = image_util.get_center_crop(image=cropped.labels[channel],
                                                                    crop_shape=center_size)
        else:
            # The center region is the whole crop, so the cropped arrays are reused directly (no copy).
            mask_center = cropped.mask
            labels_center = cropped.labels

        return CroppedSample(image=cropped.image,
                             mask=cropped.mask,
                             labels=cropped.labels,
                             mask_center_crop=mask_center,
                             labels_center_crop=labels_center,
                             center_indices=center_point,
                             metadata=cropped.metadata)
Example #9
0
    def training_or_validation_step(self,
                                    sample: Dict[str, Any],
                                    batch_index: int,
                                    is_training: bool) -> torch.Tensor:
        """
        Runs one minibatch of training or validation data through the model. It computes the loss and all
        metrics, and writes the loss.

        :param sample: The batched sample, as a dictionary that CroppedSample.from_dict can convert.
        :param batch_index: The index of the present batch (supplied only for diagnostics; not read here).
        :param is_training: If true, the method is called from `training_step`, otherwise it is called from
            `validation_step`. When false, the forward pass runs under torch.no_grad().
        :return: The loss for this minibatch.
        """
        cropped_sample: CroppedSample = CroppedSample.from_dict(sample=sample)
        # The model output can be smaller than the input crop, so the loss is computed against the center
        # crop of the labels, which is the part of the labels tensor the model actually produces.
        target = cropped_sample.labels_center_crop

        if is_training:
            logits = self.model(cropped_sample.image)
            # NOTE(review): the mask is applied only during training here, whereas
            # forward_and_backward_minibatch passes the mask only outside training mode - confirm intent.
            mask = cropped_sample.mask_center_crop
        else:
            with torch.no_grad():
                logits = self.model(cropped_sample.image)
            mask = None

        loss = self.loss_fn(logits, target)

        # Map the model output to posterior probabilities (intended: softmax over the class dimension 1).
        posteriors = self.logits_to_posterior(logits)
        if mask is not None:
            posteriors = image_util.apply_mask_to_posteriors(posteriors=posteriors, mask=mask)  # type: ignore

        segmentation = image_util.posteriors_to_segmentation(posteriors=posteriors)  # type: ignore
        self.compute_metrics(cropped_sample, segmentation, is_training)  # type: ignore
        self.write_loss(is_training, loss)
        return loss
Example #10
0
    def forward_and_backward_minibatch(
            self, sample: Dict[str, Any], batch_index: int,
            epoch: int) -> ModelForwardAndBackwardsOutputs:
        """
        Runs one minibatch of training or validation data through the pipeline. It records the loss, the
        per-class Dice, the foreground voxel counts and the patch-center diagnostics in self.metrics.

        Outside of training mode, the center crop of the mask is passed to the pipeline's forward pass. In
        training mode no mask is passed; for the batch selected by self.example_to_save, a dataset sample is
        stored for visualization if the model config enables it.

        :param sample: The batched sample, as a dictionary that CroppedSample.from_dict can convert.
        :param batch_index: The index of the present batch. It is compared against self.example_to_save to
            decide whether a dataset sample is stored.
        :param epoch: The number of the present epoch. NOTE(review): not read here; the epoch passed to
            _store_dataset_sample comes from self.train_val_params.epoch - confirm they always agree.
        :return: The loss, with the pipeline's posteriors as `logits` and its segmentations as `labels`.
        :raises ValueError: If the forward pass did not compute a loss.
        """
        cropped_sample: CroppedSample = CroppedSample.from_dict(sample=sample)
        labels = self.model_config.get_gpu_tensor_if_possible(
            cropped_sample.labels_center_crop)

        mask = None if self.train_val_params.in_training_mode else cropped_sample.mask_center_crop
        forward_pass_result = self.pipeline.forward_pass_patches(
            patches=cropped_sample.image, labels=labels, mask=mask)
        # Fail fast: without a loss nothing can be recorded or returned, so skip the Dice computation.
        if forward_pass_result.loss is None:
            raise ValueError(
                "During training, the loss should always be computed, but the value is None."
            )
        # loss is a scalar, also when running the forward pass over multiple crops.
        loss = forward_pass_result.loss

        # Clear the GPU cache between forward and backward passes to avoid possible out-of-memory
        torch.cuda.empty_cache()
        # dice_for_all_classes has one row per crop and is indexed as [crop, class].
        dice_for_all_classes = metrics.compute_dice_across_patches(
            segmentation=torch.tensor(
                forward_pass_result.segmentations).long(),
            ground_truth=labels,
            use_cuda=self.model_config.use_gpu,
            allow_multiple_classes_for_each_pixel=True).cpu().numpy()
        foreground_voxels = metrics_util.get_number_of_voxels_per_class(
            cropped_sample.labels)

        # store metrics per batch
        self.metrics.add_metric(MetricType.LOSS, loss)
        for class_index, ground_truth_id in enumerate(
                self.metrics.get_hue_names(include_default=False)):
            for crop_index in range(dice_for_all_classes.shape[0]):
                self.metrics.add_metric(MetricType.DICE,
                                        dice_for_all_classes[crop_index, class_index].item(),
                                        hue=ground_truth_id,
                                        skip_nan_when_averaging=True)
            self.metrics.add_metric(MetricType.VOXEL_COUNT,
                                    foreground_voxels[class_index],
                                    hue=ground_truth_id)

        # store diagnostics per batch; tensors are converted to numpy before copying
        center_indices = cropped_sample.center_indices
        if isinstance(center_indices, torch.Tensor):
            center_indices = center_indices.cpu().numpy()
        self.metrics.add_diagnostics(MetricType.PATCH_CENTER.value,
                                     np.copy(center_indices))

        # store the sample train patch from this epoch for visualization
        if (self.train_val_params.in_training_mode
                and batch_index == self.example_to_save
                and self.model_config.store_dataset_sample):
            _store_dataset_sample(self.model_config,
                                  self.train_val_params.epoch,
                                  forward_pass_result, cropped_sample)

        return ModelForwardAndBackwardsOutputs(
            loss=loss,
            logits=forward_pass_result.posteriors,
            labels=forward_pass_result.segmentations)