Example #1
    def test_inverse_inferred_seg(self):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({"image": image, "label": label.astype(np.float32)})

        batch_size = 10
        # use no multiprocessing workers on macOS
        num_workers = 2 if sys.platform != "darwin" else 0
        transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))])
        num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(2,),
        ).to(device)

        data = first(loader)
        labels = data["label"].to(device)
        # forward pass: the label batch is used as the model input here simply to
        # produce an output batch of the right shape to invert
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX
        segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

        segs_dict_decollated = decollate_batch(segs_dict)
        # inverse of an individual segmentation
        seg_dict = first(segs_dict_decollated)
        # convert the recorded interpolation mode for this single item of the model output batch
        convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data[label_transform_key]), num_invertible_transforms)
        self.assertEqual(len(seg_dict[label_transform_key]), num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

        # Inverse of batch
        batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation)
        with allow_missing_keys_mode(transforms):
            inv_batch = batch_inverter(segs_dict)
        self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
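
The test above relies on module-level setup that is not shown. The following is a minimal sketch of the assumed imports and the KEYS constant, based on an older MONAI API; the exact module paths vary between MONAI releases, and the method itself belongs to a unittest.TestCase subclass.

import sys
import unittest

import numpy as np
import torch

from monai.data import CacheDataset, DataLoader, create_test_image_2d, decollate_batch
from monai.data.utils import no_collation
from monai.networks.nets import UNet
from monai.transforms import (
    AddChanneld,
    BatchInverseTransform,
    CenterSpatialCropd,
    Compose,
    InvertibleTransform,
    SpatialPadd,
    allow_missing_keys_mode,
    convert_inverse_interp_mode,
)
from monai.transforms.inverse import InverseKeys
from monai.utils import first

# dictionary keys that the "d" (dictionary-based) transforms operate on
KEYS = ["image", "label"]
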
    def __call__(
        self,
        data: Dict[str, Any],
        num_examples: int = 10
    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]:
        """
        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across
                `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output,
                including `num_examples`. See original paper for clarification.
            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then concatenating across
                the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired.
        """
        d = dict(data)

        # check that num_examples is a multiple of batch_size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be a multiple of batch_size.")

        # replicate the input num_examples times and build a dataset/dataloader
        # that yields the (re-transformed) copies in batches of size batch_size
        data_in = [d] * num_examples
        ds = Dataset(data_in, self.transform)
        dl = DataLoader(ds,
                        num_workers=self.num_workers,
                        batch_size=self.batch_size,
                        collate_fn=pad_list_data_collate)

        label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX

        # create the batch inverter (applies the inverse of `transform` to each item of a batch)
        inverter = BatchInverseTransform(self.transform,
                                         dl,
                                         collate_fn=list_data_collate)

        outputs: List[np.ndarray] = []

        for batch_data in dl:

            batch_images = batch_data[self.image_key].to(self.device)

            # do model forward pass
            batch_output = self.inferrer_fn(batch_images)
            if isinstance(batch_output, torch.Tensor):
                batch_output = batch_output.detach().cpu()
            if isinstance(batch_output, np.ndarray):
                batch_output = torch.Tensor(batch_output)

            # create a dictionary containing the inferred batch and their transforms
            inferred_dict = {
                self.label_key: batch_output,
                label_transform_key: batch_data[label_transform_key]
            }

            # do inverse transformation (allow missing keys as only inverting label)
            with allow_missing_keys_mode(self.transform):  # type: ignore
                inv_batch = inverter(inferred_dict)

            # keep the inverted labels from this batch
            outputs.append(inv_batch[self.label_key])

        # stack all realisations along the first dimension
        output: np.ndarray = np.concatenate(outputs)

        if self.return_full_data:
            return output

        # calculate summary metrics across the num_examples realisations
        # per-voxel mode of the (integer-valued) outputs
        mode: np.ndarray = np.apply_along_axis(
            lambda x: np.bincount(x).argmax(),
            axis=0,
            arr=output.astype(np.int64))
        mean: np.ndarray = np.mean(output, axis=0)  # type: ignore
        std: np.ndarray = np.std(output, axis=0)  # type: ignore
        # volume variation coefficient: std / mean over the whole stacked output
        vvc: float = (np.std(output) / np.mean(output)).item()
        return mode, mean, std, vvc
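
For context, this __call__ matches the pattern of MONAI's test-time augmentation helper (TestTimeAugmentation). Below is a hypothetical usage sketch; the import location, constructor arguments, UNet settings and the RandAffined configuration are assumptions for an older MONAI release and should be verified against the installed version.

import numpy as np
import torch

from monai.data import TestTimeAugmentation, create_test_image_2d
from monai.networks.nets import UNet
from monai.transforms import AddChanneld, Compose, RandAffined

# a single dictionary data item, as expected by __call__
image, label = create_test_image_2d(128, 128)
datum = {"image": image, "label": label.astype(np.float32)}

# the transform should include at least one randomised, invertible transform
# (here RandAffined) so that every realisation differs
transforms = Compose([
    AddChanneld(["image", "label"]),
    RandAffined(["image", "label"], prob=1.0, rotate_range=np.pi / 6, padding_mode="zeros"),
])

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(dimensions=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)

# inferrer_fn returns discrete, non-negative predictions so that the per-voxel
# mode (computed with np.bincount) is well defined
tt_aug = TestTimeAugmentation(
    transforms,
    batch_size=5,
    num_workers=0,
    inferrer_fn=lambda x: (torch.sigmoid(model(x)) > 0.5).float(),
    device=device,
)
mode, mean, std, vvc = tt_aug(datum, num_examples=10)  # aggregated over 10 realisations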