Example #1
    def test_dict_examples(self):
        test_case = {
            "meta": {
                "out": ["test", "test"]
            },
            PostFix.meta("image"): {
                "scl_slope": torch.Tensor((0.0, 0.0))
            }
        }
        out = decollate_batch(test_case)
        self.assertEqual(out[0]["meta"]["out"], "test")
        self.assertEqual(out[0][PostFix.meta("image")]["scl_slope"], 0.0)

        test_case = [torch.ones((2, 1, 10, 10)), torch.ones((2, 3, 5, 5))]
        out = decollate_batch(test_case)
        self.assertTupleEqual(out[0][0].shape, (1, 10, 10))
        self.assertTupleEqual(out[0][1].shape, (3, 5, 5))

        test_case = torch.rand((2, 1, 10, 10))
        out = decollate_batch(test_case)
        self.assertTupleEqual(out[0].shape, (1, 10, 10))

        test_case = [torch.tensor(0), torch.tensor(0)]
        out = decollate_batch(test_case, detach=True)
        self.assertListEqual([0, 0], out)
        self.assertFalse(isinstance(out[0], torch.Tensor))

        test_case = {"a": [torch.tensor(0), torch.tensor(0)]}
        out = decollate_batch(test_case, detach=False)
        self.assertListEqual([{
            "a": torch.tensor(0)
        }, {
            "a": torch.tensor(0)
        }], out)
        self.assertTrue(isinstance(out[0]["a"], torch.Tensor))

        test_case = [torch.tensor(0), torch.tensor(0)]
        out = decollate_batch(test_case, detach=False)
        self.assertListEqual(test_case, out)

        test_case = {
            "image": torch.tensor([[[1, 2]], [[3, 4]]]),
            "label": torch.tensor([[[5, 6]], [[7, 8]]]),
            "pred": torch.tensor([[[9, 10]], [[11, 12]]]),
            "out": ["test"],
        }
        out = decollate_batch(test_case, detach=False)
        self.assertEqual(out[0]["out"], "test")

        test_case = {
            "image": torch.tensor([[[1, 2, 3]], [[3, 4, 5]]]),
            "label": torch.tensor([[[5]], [[7]]]),
            "out": ["test"],
        }
        out = decollate_batch(test_case, detach=False, pad=False)
        self.assertEqual(len(out), 1)  # no padding
        out = decollate_batch(test_case, detach=False, pad=True, fill_value=0)
        self.assertEqual(out[1]["out"], 0)  # verify padding fill_value
Example #2
    def test_decollation(self, *transforms):

        batch_size = 2
        num_workers = 2

        t_compose = Compose(
            [AddChanneld(KEYS),
             Compose(transforms),
             ToTensord(KEYS)])
        # If nibabel present, read from disk
        if has_nib:
            t_compose = Compose([LoadImaged("image"), t_compose])

        dataset = CacheDataset(self.data, t_compose, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        for b, batch_data in enumerate(loader):
            decollated_1 = decollate_batch(batch_data)
            decollated_2 = Decollated()(batch_data)

            for decollated in [decollated_1, decollated_2]:
                for i, d in enumerate(decollated):
                    self.check_match(dataset[b * batch_size + i], d)
Example #3
    def test_decollation(self, batch_size=2, num_workers=2):

        im = create_test_image_2d(100, 101)[0]
        data = [{
            "image": make_nifti_image(im) if has_nib else im
        } for _ in range(6)]

        transforms = Compose([
            AddChanneld("image"),
            SpatialPadd("image", 150),
            RandFlipd("image", prob=1.0, spatial_axis=1),
            ToTensord("image"),
        ])
        # If nibabel present, read from disk
        if has_nib:
            transforms = Compose([LoadImaged("image"), transforms])

        dataset = CacheDataset(data, transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        for b, batch_data in enumerate(loader):
            decollated_1 = decollate_batch(batch_data)
            decollated_2 = Decollated()(batch_data)

            for decollated in [decollated_1, decollated_2]:
                for i, d in enumerate(decollated):
                    self.check_match(dataset[b * batch_size + i], d)
Example #4
    def test_pad_collation(self, t_type, collate_method, transform):

        if t_type == dict:
            dataset = CacheDataset(self.dict_data, transform, progress=False)
        else:
            dataset = _Dataset(self.list_data, self.list_labels, transform)

        # Default collation should raise an error
        loader_fail = DataLoader(dataset, batch_size=10)
        with self.assertRaises(RuntimeError):
            for _ in loader_fail:
                pass

        # Padded collation shouldn't raise
        loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)
        # check collation in forward direction
        for data in loader:
            if t_type == dict:
                shapes = []
                decollated_data = decollate_batch(data)
                for d in decollated_data:
                    output = PadListDataCollate.inverse(d)
                    shapes.append(output["image"].shape)
                self.assertTrue(
                    len(set(shapes)) > 1
                )  # inverted shapes must be different because of random xforms
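The padded collation exercised here can also be reproduced in isolation. A minimal sketch, assuming `pad_list_data_collate` from `monai.data` (the item shapes are illustrative):

    import torch
    from monai.data import pad_list_data_collate

    # items with mismatched spatial sizes; torch's default collate would raise RuntimeError
    items = [{"image": torch.rand(1, 8, 8)}, {"image": torch.rand(1, 10, 10)}]
    batch = pad_list_data_collate(items)  # pads each item to (1, 10, 10) before stacking
    print(batch["image"].shape)  # expected: torch.Size([2, 1, 10, 10])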
Example #5
    def test_inverse_inferred_seg(self):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({
                "image": image,
                "label": label.astype(np.float32)
            })

        batch_size = 10
        # num workers = 0 for mac
        num_workers = 2 if sys.platform != "darwin" else 0
        transforms = Compose([
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)),
            CenterSpatialCropd(KEYS, (110, 99))
        ])
        num_invertible_transforms = sum(1 for i in transforms.transforms
                                        if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(2, ),
        ).to(device)

        data = first(loader)
        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value
        segs_dict = {
            "label": segs,
            label_transform_key: data[label_transform_key]
        }

        segs_dict_decollated = decollate_batch(segs_dict)

        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data["label_transforms"]),
                         num_invertible_transforms)
        self.assertEqual(len(seg_dict["label_transforms"]),
                         num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
Example #6
    def __call__(
        self,
        data: Dict[str, Any],
        num_examples: int = 10
    ) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float],
               NdarrayOrTensor]:
        """
        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are
                calculated across the `num_examples` outputs at each voxel. The volume variation coefficient (VVC)
                is `std/mean` across the whole output, including the `num_examples` dimension. See the original
                paper for clarification.
            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
                concatenating across the first dimension containing `num_examples`. This allows the user to perform
                their own analysis if desired.
        """
        d = dict(data)

        # check num examples is multiple of batch size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be multiple of batch size.")

        # generate batch of data of size == batch_size, dataset and dataloader
        data_in = [deepcopy(d) for _ in range(num_examples)]
        ds = Dataset(data_in, self.transform)
        dl = DataLoader(ds,
                        num_workers=self.num_workers,
                        batch_size=self.batch_size,
                        collate_fn=pad_list_data_collate)

        outs: List = []

        for b in tqdm(dl) if has_tqdm and self.progress else dl:
            # do model forward pass
            b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(
                self.device))
            outs.extend([
                self.inverter(PadListDataCollate.inverse(i))[self._pred_key]
                for i in decollate_batch(b)
            ])

        output: NdarrayOrTensor = stack(outs, 0)

        if self.return_full_data:
            return output

        # calculate metrics
        _mode = mode(output, dim=0)
        mean = output.mean(0)
        std = output.std(0)
        vvc = (output.std() / output.mean()).item()

        return _mode, mean, std, vvc
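The metric block at the end of this example can be sanity-checked on its own. A minimal sketch in plain PyTorch with a random stand-in for the stacked `output` (note `torch.mode` stands in for the `mode` helper used above):

    import torch

    output = torch.rand(10, 1, 4, 4)  # stand-in for (num_examples, C, H, W)
    _mode = torch.mode(output, dim=0).values     # per-voxel mode across realisations
    mean = output.mean(0)                        # per-voxel mean
    std = output.std(0)                          # per-voxel standard deviation
    vvc = (output.std() / output.mean()).item()  # scalar VVC: std/mean over everything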
Example #7
 def test_decollate(self, dtype):
     batch_size = 3
     ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)]
     ds = Dataset(ims)
     dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size)
     batch = next(iter(dl))
     decollated = decollate_batch(batch)
     self.assertIsInstance(decollated, list)
     self.assertEqual(len(decollated), batch_size)
     for elem, im in zip(decollated, ims):
         self.assertIsInstance(elem, MetaTensor)
         self.check(elem, im, ids=False)
Example #8
    def check_decollate(self, dataset):
        batch_size = 2
        num_workers = 2

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        for b, batch_data in enumerate(loader):
            decollated_1 = decollate_batch(batch_data)
            decollated_2 = Decollated(detach=True)(batch_data)

            for decollated in [decollated_1, decollated_2]:
                for i, d in enumerate(decollated):
                    self.check_match(dataset[b * batch_size + i], d)
Example #9
 def __call__(self, data: Dict[str, Any]) -> Any:
     decollated_data = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)
     inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used)
     inv_loader = DataLoader(
         inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn
     )
     try:
         return first(inv_loader)
     except RuntimeError as re:
         re_str = str(re)
         if "equal size" in re_str:
             re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`."
         raise RuntimeError(re_str) from re
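The hint in the error path corresponds to construction like the following. A minimal sketch mirroring the usage in Example #12 below; `transforms`, `loader`, and `segs_dict` are assumed to exist as in that example:

    from monai.data.utils import no_collation
    from monai.transforms import BatchInverseTransform

    # inverted samples may differ in spatial shape, so skip re-collation;
    # `collate_fn=lambda x: x` (as the hint suggests) behaves the same way
    batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation)
    inverted = batch_inverter(segs_dict)  # list of per-sample inverted dicts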
Example #10
    def __call__(self, data: Union[Dict, List]):
        d: Union[Dict, List]
        if len(self.keys) == 1 and self.keys[0] is None:
            # `None` is not supported as an actual dict key, so decollate the whole input
            d = data
        else:
            if not isinstance(data, dict):
                raise TypeError("input data is not a dictionary, but specified keys to decollate.")
            d = {}
            for key in self.key_iterator(data):
                d[key] = data[key]

        return decollate_batch(d, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)
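In usage terms, the key handling above means `Decollated` can split either the entire dictionary or only selected keys. A minimal sketch (keys and shapes are illustrative):

    import torch
    from monai.transforms import Decollated

    batch = {"image": torch.rand(2, 1, 4, 4), "label": torch.rand(2, 1, 4, 4)}
    whole = Decollated()(batch)               # list of 2 dicts holding both keys
    images = Decollated(keys="image")(batch)  # list of 2 dicts holding "image" only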
Example #11
    def test_dict_examples(self):
        test_case = {
            "meta": {
                "out": ["test", "test"]
            },
            "image_meta_dict": {
                "scl_slope": torch.Tensor((0.0, 0.0))
            }
        }
        out = decollate_batch(test_case)
        self.assertEqual(out[0]["meta"]["out"], "test")
        self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0)

        test_case = [torch.ones((2, 1, 10, 10)), torch.ones((2, 3, 5, 5))]
        out = decollate_batch(test_case)
        self.assertTupleEqual(out[0][0].shape, (1, 10, 10))
        self.assertTupleEqual(out[0][1].shape, (3, 5, 5))

        test_case = torch.rand((2, 1, 10, 10))
        out = decollate_batch(test_case)
        self.assertTupleEqual(out[0].shape, (1, 10, 10))

        test_case = [torch.tensor(0), torch.tensor(0)]
        out = decollate_batch(test_case, detach=True)
        self.assertListEqual([0, 0], out)
        self.assertFalse(isinstance(out[0], torch.Tensor))

        test_case = {"a": [torch.tensor(0), torch.tensor(0)]}
        out = decollate_batch(test_case, detach=False)
        self.assertListEqual([{
            "a": torch.tensor(0)
        }, {
            "a": torch.tensor(0)
        }], out)
        self.assertTrue(isinstance(out[0]["a"], torch.Tensor))

        test_case = [torch.tensor(0), torch.tensor(0)]
        out = decollate_batch(test_case, detach=False)
        self.assertListEqual(test_case, out)

        test_case = {
            "image": torch.tensor([[[1, 2]], [[3, 4]]]),
            "label": torch.tensor([[[5, 6]], [[7, 8]]]),
            "pred": torch.tensor([[[9, 10]], [[11, 12]]]),
            "out": ["test"],
        }
        out = decollate_batch(test_case, detach=False)
        self.assertEqual(out[0]["out"], "test")
Example #12
    def test_inverse_inferred_seg(self, extra_transform):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({"image": image, "label": label.astype(np.float32)})

        batch_size = 10
        # num workers = 0 on non-linux platforms (e.g., mac)
        num_workers = 2 if sys.platform == "linux" else 0
        transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])
        num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)

        data = first(loader)
        self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
        self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX
        segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

        segs_dict_decollated = decollate_batch(segs_dict)
        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        # test to convert interpolation mode for 1 data of model output batch
        convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
        self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

        # Inverse of batch
        batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
        with allow_missing_keys_mode(transforms):
            inv_batch = batch_inverter(segs_dict)
        self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
Example #13
    def test_pad_collation(self, t_type, collate_method, transform):

        if t_type == dict:
            dataset = CacheDataset(self.dict_data, transform, progress=False)
        else:
            dataset = _Dataset(self.list_data, self.list_labels, transform)

        # Default collation should raise an error
        loader_fail = DataLoader(dataset, batch_size=10)
        with self.assertRaises(RuntimeError):
            for _ in loader_fail:
                pass

        # Padded collation shouldn't raise
        loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)
        # check collation in forward direction
        for data in loader:
            if t_type == dict:
                decollated_data = decollate_batch(data)
                for d in decollated_data:
                    PadListDataCollate.inverse(d)
Example #14
    def update_meta(rets: Sequence, func, args, kwargs):
        """Update the metadata from the output of `__torch_function__`.
        The output could be a single object, or a sequence of them. Hence, they get
        converted to a sequence if necessary and then processed by iterating across them.

        For each element, if not of type `MetaTensor`, then there is nothing to do.
        """
        out = []
        metas = None
        for idx, ret in enumerate(rets):
            # if not `MetaTensor`, nothing to do.
            if not isinstance(ret, MetaTensor):
                pass
            # if not tracking, convert to `torch.Tensor`.
            elif not (get_track_meta() or get_track_transforms()):
                ret = ret.as_tensor()
            # else, handle the `MetaTensor` metadata.
            else:
                meta_args = MetaObj.flatten_meta_objs(
                    list(args) + list(kwargs.values()))
                ret._copy_meta(meta_args)

                # If we have a batch of data, then we need to be careful if a slice of
                # the data is returned. Depending on how the data are indexed, we return
                # some or all of the metadata, and the return object may or may not be a
                # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
                if ret.is_batch:
                    # only decollate metadata once
                    if metas is None:
                        metas = decollate_batch(ret.meta)
                    # if indexing e.g., `batch[0]`
                    if func == torch.Tensor.__getitem__:
                        idx = args[1]
                        if isinstance(idx, Sequence):
                            idx = idx[0]
                        # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                        # first element will be `slice(None, None, None)` and `Ellipsis`,
                        # respectively. Don't need to do anything with the metadata.
                        if idx not in (slice(None, None, None), Ellipsis):
                            meta = metas[idx]
                            # if using e.g., `batch[0:2]`, then `is_batch` should still be
                            # `True`. Also re-collate the remaining elements.
                            if isinstance(meta, list) and len(meta) > 1:
                                ret.meta = list_data_collate(meta)
                            # if using e.g., `batch[0]` or `batch[0, 1]`, then return single
                            # element from batch, and set `is_batch` to `False`.
                            else:
                                ret.meta = meta
                                ret.is_batch = False
                    # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                    # But we only want to split the batch if the `unbind` is along the 0th
                    # dimension.
                    elif func == torch.Tensor.unbind:
                        if len(args) > 1:
                            dim = args[1]
                        elif "dim" in kwargs:
                            dim = kwargs["dim"]
                        else:
                            dim = 0
                        if dim == 0:
                            ret.meta = metas[idx]
                            ret.is_batch = False

                ret.affine = ret.affine.to(ret.device)
            out.append(ret)
        # if the input was a tuple, then return it as a tuple
        return tuple(out) if isinstance(rets, tuple) else out
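The `unbind` branch above is what drives iteration over a batch. A minimal sketch of that path, assuming `MetaTensor` and `list_data_collate` from `monai.data` (behaviour as described in the comments; details may vary across MONAI versions):

    import torch
    from monai.data import MetaTensor, list_data_collate

    batch = list_data_collate([MetaTensor(torch.rand(1, 4, 4)) for _ in range(3)])
    first_elem = next(iter(batch))       # iteration unbinds along dim 0
    assert first_elem.is_batch is False  # a single element, no longer a batch
    parts = torch.unbind(batch, dim=0)   # three MetaTensors with per-element metadata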
Example #15
 def test_decollation_examples(self, input_val, expected_out):
     out = decollate_batch(input_val)
     self.assertListEqual(expected_out, out)
Example #16
 def __call__(self, data: dict):
     return decollate_batch(data, detach=self.detach)
Example #17
    def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
        """
        Update the metadata from the output of `MetaTensor.__torch_function__`.

        The output of `torch.Tensor.__torch_function__` could be a single object or a
        sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
        list if not already, and then we loop across each element, processing metadata
        as necessary. For each element, if not of type `MetaTensor`, then nothing to do.

        Args:
            rets: the output from `torch.Tensor.__torch_function__`, which has been
                converted to a list in `MetaTensor.__torch_function__` if it wasn't
                already a `Sequence`.
            func: the torch function that was applied. Examples might be `torch.squeeze`
                or `torch.Tensor.__add__`. We need this since the metadata need to be
                treated differently if a batch of data is considered. For example,
                slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
                dimension of a batch of data should return the ith tensor with the ith
                metadata.
            args: positional arguments that were passed to `func`.
            kwargs: keyword arguments that were passed to `func`.

        Returns:
            A sequence with the same number of elements as `rets`. For each element, if
            the input type was not `MetaTensor`, then no modifications will have been
            made. If global parameters have been set to false (e.g.,
            `not get_track_meta()`), then any `MetaTensor` will be converted to
            `torch.Tensor`. Else, metadata will be propagated as necessary (see
            :py:func:`MetaTensor._copy_meta`).
        """
        out = []
        metas = None
        is_batch = any(
            x.is_batch
            for x in MetaObj.flatten_meta_objs(args, kwargs.values())
            if hasattr(x, "is_batch"))
        for idx, ret in enumerate(rets):
            # if not `MetaTensor`, nothing to do.
            if not isinstance(ret, MetaTensor):
                pass
            # if not tracking, convert to `torch.Tensor`.
            elif not get_track_meta():
                ret = ret.as_tensor()
            # else, handle the `MetaTensor` metadata.
            else:
                meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())
                ret.is_batch = is_batch
                ret.copy_meta_from(meta_args, copy_attr=not is_batch)
                # the following is not implemented but the network arch may run into this case:
                # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
                #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")

                # If we have a batch of data, then we need to be careful if a slice of
                # the data is returned. Depending on how the data are indexed, we return
                # some or all of the metadata, and the return object may or may not be a
                # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
                if is_batch:
                    # if indexing e.g., `batch[0]`
                    if func == torch.Tensor.__getitem__:
                        batch_idx = args[1]
                        if isinstance(batch_idx, Sequence):
                            batch_idx = batch_idx[0]
                        # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                        # first element will be `slice(None, None, None)` and `Ellipsis`,
                        # respectively. Don't need to do anything with the metadata.
                        if batch_idx not in (slice(None, None, None), Ellipsis,
                                             None) and idx == 0:
                            ret_meta = decollate_batch(args[0],
                                                       detach=False)[batch_idx]
                            if isinstance(ret_meta,
                                          list):  # e.g. batch[0:2], re-collate
                                ret_meta = list_data_collate(ret_meta)
                            else:  # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
                                ret_meta.is_batch = False
                            ret.__dict__ = ret_meta.__dict__.copy()
                    # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                    # But we only want to split the batch if the `unbind` is along the 0th
                    # dimension.
                    elif func == torch.Tensor.unbind:
                        if len(args) > 1:
                            dim = args[1]
                        elif "dim" in kwargs:
                            dim = kwargs["dim"]
                        else:
                            dim = 0
                        if dim == 0:
                            if metas is None:
                                metas = decollate_batch(args[0], detach=False)
                            ret.__dict__ = metas[idx].__dict__.copy()
                            ret.is_batch = False

            out.append(ret)
        # if the input was a tuple, then return it as a tuple
        return tuple(out) if isinstance(rets, tuple) else out
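The `__getitem__` branch can likewise be observed directly: integer indexing yields a single element, while slicing keeps a (re-collated) batch. A minimal sketch under the same assumptions as the one after Example #14:

    import torch
    from monai.data import MetaTensor, list_data_collate

    batch = list_data_collate([MetaTensor(torch.rand(1, 4, 4)) for _ in range(3)])
    single = batch[0]   # integer index: one element, `is_batch` becomes False
    pair = batch[0:2]   # slice: metadata re-collated, still a batch of 2
    assert not single.is_batch and pair.is_batch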
Example #18
 def __call__(self, data: dict) -> List[dict]:
     return decollate_batch(data, self.batch_size)