def test_dict_examples(self):
    """Exercise `decollate_batch` on a range of inputs: nested dicts, lists of
    batched tensors, a raw batched tensor, 0-d tensors with/without `detach`,
    and ragged dict values with/without padding."""
    # nested dicts: list values and batched tensors are split across the batch dim
    test_case = {"meta": {"out": ["test", "test"]}, PostFix.meta("image"): {"scl_slope": torch.Tensor((0.0, 0.0))}}
    out = decollate_batch(test_case)
    self.assertEqual(out[0]["meta"]["out"], "test")
    self.assertEqual(out[0][PostFix.meta("image")]["scl_slope"], 0.0)

    # list of batched tensors with different spatial sizes: each is split along dim 0
    test_case = [torch.ones((2, 1, 10, 10)), torch.ones((2, 3, 5, 5))]
    out = decollate_batch(test_case)
    self.assertTupleEqual(out[0][0].shape, (1, 10, 10))
    self.assertTupleEqual(out[0][1].shape, (3, 5, 5))

    # a single batched tensor decollates to per-sample tensors
    test_case = torch.rand((2, 1, 10, 10))
    out = decollate_batch(test_case)
    self.assertTupleEqual(out[0].shape, (1, 10, 10))

    # with detach=True, 0-d tensors become python scalars
    test_case = [torch.tensor(0), torch.tensor(0)]
    out = decollate_batch(test_case, detach=True)
    self.assertListEqual([0, 0], out)
    self.assertFalse(isinstance(out[0], torch.Tensor))

    # with detach=False, 0-d tensors are preserved as tensors
    test_case = {"a": [torch.tensor(0), torch.tensor(0)]}
    out = decollate_batch(test_case, detach=False)
    self.assertListEqual([{"a": torch.tensor(0)}, {"a": torch.tensor(0)}], out)
    self.assertTrue(isinstance(out[0]["a"], torch.Tensor))

    test_case = [torch.tensor(0), torch.tensor(0)]
    out = decollate_batch(test_case, detach=False)
    self.assertListEqual(test_case, out)

    # dict mixing batched tensors with a length-1 list: non-tensor values follow along
    test_case = {
        "image": torch.tensor([[[1, 2]], [[3, 4]]]),
        "label": torch.tensor([[[5, 6]], [[7, 8]]]),
        "pred": torch.tensor([[[9, 10]], [[11, 12]]]),
        "out": ["test"],
    }
    out = decollate_batch(test_case, detach=False)
    self.assertEqual(out[0]["out"], "test")

    # ragged dict values (batch sizes 2 vs 1): pad=False truncates to the shortest,
    # pad=True extends shorter entries with `fill_value`
    test_case = {
        "image": torch.tensor([[[1, 2, 3]], [[3, 4, 5]]]),
        "label": torch.tensor([[[5]], [[7]]]),
        "out": ["test"],
    }
    out = decollate_batch(test_case, detach=False, pad=False)
    self.assertEqual(len(out), 1)  # no padding
    out = decollate_batch(test_case, detach=False, pad=True, fill_value=0)
    self.assertEqual(out[1]["out"], 0)  # verify padding fill_value
def test_decollation(self, *transforms):
    """Run `self.data` through a loader, decollate the batches with both the
    functional and the transform-based API, and match each decollated item
    against the corresponding dataset entry."""
    n_per_batch, n_workers = 2, 2
    pipeline = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)])
    # when nibabel is available the images are on disk, so load them first
    if has_nib:
        pipeline = Compose([LoadImaged("image"), pipeline])

    ds = CacheDataset(self.data, pipeline, progress=False)
    loader = DataLoader(ds, batch_size=n_per_batch, shuffle=False, num_workers=n_workers)

    for batch_idx, batch in enumerate(loader):
        # exercise both decollate entry points on the same batch
        for items in (decollate_batch(batch), Decollated()(batch)):
            for offset, item in enumerate(items):
                self.check_match(ds[batch_idx * n_per_batch + offset], item)
def test_decollation(self, batch_size=2, num_workers=2):
    """Build a small synthetic dataset, batch it with a DataLoader, then check
    that decollating (via both APIs) reproduces the per-item data."""
    im = create_test_image_2d(100, 101)[0]
    # six copies of the same image; written to nifti files when nibabel is present
    data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)]

    transforms = Compose(
        [
            AddChanneld("image"),
            SpatialPadd("image", 150),
            RandFlipd("image", prob=1.0, spatial_axis=1),
            ToTensord("image"),
        ]
    )
    if has_nib:
        # images are on disk in this case, so read them first
        transforms = Compose([LoadImaged("image"), transforms])

    ds = CacheDataset(data, transforms, progress=False)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    for b, batch_data in enumerate(loader):
        for decollated in (decollate_batch(batch_data), Decollated()(batch_data)):
            for i, d in enumerate(decollated):
                self.check_match(ds[b * batch_size + i], d)
def test_pad_collation(self, t_type, collate_method, transform):
    """Differently-sized samples should fail under default collation but
    succeed with the padded collate function; for dict data, inverting the
    padding must recover the (random, hence differing) original shapes."""
    if t_type == dict:
        dataset = CacheDataset(self.dict_data, transform, progress=False)
    else:
        dataset = _Dataset(self.list_data, self.list_labels, transform)

    # Default collation should raise an error
    loader_fail = DataLoader(dataset, batch_size=10)
    with self.assertRaises(RuntimeError):
        for _ in loader_fail:
            pass

    # Padded collation shouldn't
    loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)
    # check collation in forward direction
    for data in loader:
        if t_type == dict:
            shapes = []
            decollated_data = decollate_batch(data)
            for d in decollated_data:
                output = PadListDataCollate.inverse(d)
                shapes.append(output["image"].shape)
            self.assertTrue(len(set(shapes)) > 1)  # inverted shapes must be different because of random xforms
def test_inverse_inferred_seg(self):
    """Run model output through `decollate_batch` and invert the spatial
    transforms on a single decollated segmentation, checking that the
    transform trace lengths match and the inverted shape equals the input."""
    test_data = []
    for _ in range(20):
        image, label = create_test_image_2d(100, 101)
        test_data.append({"image": image, "label": label.astype(np.float32)})

    batch_size = 10
    # num workers = 0 for mac
    num_workers = 2 if sys.platform != "darwin" else 0
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))])
    num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # tiny UNet: this test only needs a forward pass, not a trained model
    model = UNet(dimensions=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)

    data = first(loader)
    labels = data["label"].to(device)
    segs = model(labels).detach().cpu()
    # pair the predictions with the label's recorded transform trace so they can be inverted
    label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value
    segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

    segs_dict_decollated = decollate_batch(segs_dict)
    # inverse of individual segmentation
    seg_dict = first(segs_dict_decollated)
    # "image" key is absent from seg_dict, hence the missing-keys mode
    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse(seg_dict)["label"]
    self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
    self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
def __call__(
    self, data: Dict[str, Any], num_examples: int = 10
) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float], NdarrayOrTensor]:
    """
    Args:
        data: dictionary data to be processed.
        num_examples: number of realisations to be processed and results combined.

    Returns:
        - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are
          calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC)
          is `std/mean` across the whole output, including `num_examples`. See original paper for
          clarification.
        - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
          concatenating across the first dimension containing `num_examples`. This allows the user to
          perform their own analysis if desired.

    Raises:
        ValueError: if `num_examples` is not a multiple of `self.batch_size`.
    """
    d = dict(data)

    # check num examples is multiple of batch size
    if num_examples % self.batch_size != 0:
        raise ValueError("num_examples should be multiple of batch size.")

    # generate batch of data of size == batch_size, dataset and dataloader
    # (each copy passes independently through the random `self.transform`)
    data_in = [deepcopy(d) for _ in range(num_examples)]
    ds = Dataset(data_in, self.transform)
    dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate)

    outs: List = []
    # progress bar only when tqdm is importable and progress reporting was requested
    for b in tqdm(dl) if has_tqdm and self.progress else dl:
        # do model forward pass
        b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
        # undo the collation padding, then invert the transforms on each item
        outs.extend(
            [self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)]
        )

    output: NdarrayOrTensor = stack(outs, 0)

    if self.return_full_data:
        return output

    # calculate metrics across the `num_examples` realisations
    _mode = mode(output, dim=0)
    mean = output.mean(0)
    std = output.std(0)
    vvc = (output.std() / output.mean()).item()
    return _mode, mean, std, vvc
def test_decollate(self, dtype):
    """Decollating one loader batch must yield a list of `batch_size`
    MetaTensors matching the first images of the dataset."""
    n = 3
    images = [self.get_im(dtype=dtype)[0] for _ in range(n * 2)]
    loader = DataLoader(Dataset(images), num_workers=n, batch_size=n)

    first_batch = next(iter(loader))
    items = decollate_batch(first_batch)

    self.assertIsInstance(items, list)
    self.assertEqual(len(items), n)
    for item, src in zip(items, images):
        self.assertIsInstance(item, MetaTensor)
        self.check(item, src, ids=False)
def check_decollate(self, dataset):
    """Helper: batch `dataset` through a DataLoader, decollate each batch via
    both APIs, and compare every item against the original dataset entry."""
    n_per_batch, n_workers = 2, 2
    loader = DataLoader(dataset, batch_size=n_per_batch, shuffle=False, num_workers=n_workers)
    for batch_idx, batch in enumerate(loader):
        for items in (decollate_batch(batch), Decollated(detach=True)(batch)):
            for offset, item in enumerate(items):
                self.check_match(dataset[batch_idx * n_per_batch + offset], item)
def __call__(self, data: Dict[str, Any]) -> Any:
    """Invert `self.transform` across a whole batch.

    The batch is decollated, wrapped in a `_BatchInverseDataset` (which applies
    the inverse transform per item), and re-collated by a DataLoader using
    `self.collate_fn`.

    Args:
        data: collated dictionary batch to be inverted.

    Returns:
        the first (and only) batch produced by the inverse loader.

    Raises:
        RuntimeError: re-raised from collation failures; when items end up with
            unequal sizes after inversion, a MONAI hint is appended to the message.
    """
    decollated_data = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)
    inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used)
    inv_loader = DataLoader(
        inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn
    )
    try:
        return first(inv_loader)
    except RuntimeError as err:  # fix: was `as re`, which shadowed the stdlib `re` module
        msg = str(err)
        if "equal size" in msg:
            # inverted items have different shapes and cannot be stacked by the default collate
            msg += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`."
        raise RuntimeError(msg) from err
def __call__(self, data: Union[Dict, List]):
    """Decollate `data`; when keys were specified, only those keys are decollated."""
    to_process: Union[Dict, List]
    if len(self.keys) == 1 and self.keys[0] is None:
        # no key filtering requested (`None` is not usable as a real key) -- decollate everything
        to_process = data
    elif not isinstance(data, dict):
        raise TypeError("input data is not a dictionary, but specified keys to decollate.")
    else:
        # keep only the requested keys
        to_process = {key: data[key] for key in self.key_iterator(data)}
    return decollate_batch(to_process, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)
def test_dict_examples(self):
    """Exercise `decollate_batch` on nested dicts, lists of batched tensors,
    raw tensors, and 0-d tensors with and without detaching."""
    # nested dict: lists and batched tensors are split per sample
    case = {"meta": {"out": ["test", "test"]}, "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}}
    res = decollate_batch(case)
    self.assertEqual(res[0]["meta"]["out"], "test")
    self.assertEqual(res[0]["image_meta_dict"]["scl_slope"], 0.0)

    # list of batched tensors with different spatial sizes
    case = [torch.ones((2, 1, 10, 10)), torch.ones((2, 3, 5, 5))]
    res = decollate_batch(case)
    self.assertTupleEqual(res[0][0].shape, (1, 10, 10))
    self.assertTupleEqual(res[0][1].shape, (3, 5, 5))

    # single batched tensor
    case = torch.rand((2, 1, 10, 10))
    res = decollate_batch(case)
    self.assertTupleEqual(res[0].shape, (1, 10, 10))

    # detach=True turns 0-d tensors into python scalars
    case = [torch.tensor(0), torch.tensor(0)]
    res = decollate_batch(case, detach=True)
    self.assertListEqual([0, 0], res)
    self.assertFalse(isinstance(res[0], torch.Tensor))

    # detach=False keeps 0-d tensors as tensors
    case = {"a": [torch.tensor(0), torch.tensor(0)]}
    res = decollate_batch(case, detach=False)
    self.assertListEqual([{"a": torch.tensor(0)}, {"a": torch.tensor(0)}], res)
    self.assertTrue(isinstance(res[0]["a"], torch.Tensor))

    case = [torch.tensor(0), torch.tensor(0)]
    res = decollate_batch(case, detach=False)
    self.assertListEqual(case, res)

    # dict mixing batched tensors with a length-1 list of strings
    case = {
        "image": torch.tensor([[[1, 2]], [[3, 4]]]),
        "label": torch.tensor([[[5, 6]], [[7, 8]]]),
        "pred": torch.tensor([[[9, 10]], [[11, 12]]]),
        "out": ["test"],
    }
    res = decollate_batch(case, detach=False)
    self.assertEqual(res[0]["out"], "test")
def test_inverse_inferred_seg(self, extra_transform):
    """Check that model outputs paired with the label transform trace can be
    inverted both per-item (after `decollate_batch`) and as a whole batch
    (via `BatchInverseTransform`)."""
    test_data = []
    for _ in range(20):
        image, label = create_test_image_2d(100, 101)
        test_data.append({"image": image, "label": label.astype(np.float32)})

    batch_size = 10
    # num workers = 0 for mac
    num_workers = 2 if sys.platform == "linux" else 0
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])
    num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # tiny UNet: only a forward pass is needed, not a trained model
    model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)

    data = first(loader)
    self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
    # `extra_transform` is expected to produce NUM_SAMPLES crops per input
    self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

    labels = data["label"].to(device)
    segs = model(labels).detach().cpu()
    # pair predictions with the label's transform trace so they can be inverted
    label_transform_key = "label" + InverseKeys.KEY_SUFFIX
    segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

    segs_dict_decollated = decollate_batch(segs_dict)
    # inverse of individual segmentation
    seg_dict = first(segs_dict_decollated)
    # test to convert interpolation mode for 1 data of model output batch
    convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)

    # "image" is absent from seg_dict, hence the missing-keys mode
    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse(seg_dict)["label"]
    self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
    self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

    # Inverse of batch
    batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
    with allow_missing_keys_mode(transforms):
        inv_batch = batch_inverter(segs_dict)
    self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
def test_pad_collation(self, t_type, collate_method, transform):
    """Differently-sized samples must break default collation but collate
    successfully with the padded collate function; padded dict items must
    also be invertible."""
    if t_type == dict:
        ds = CacheDataset(self.dict_data, transform, progress=False)
    else:
        ds = _Dataset(self.list_data, self.list_labels, transform)

    # default collation cannot stack samples of different sizes
    failing_loader = DataLoader(ds, batch_size=10)
    with self.assertRaises(RuntimeError):
        for _ in failing_loader:
            pass

    # padded collation succeeds; for dict data, also undo the padding per item
    padded_loader = DataLoader(ds, batch_size=10, collate_fn=collate_method)
    for batch in padded_loader:
        if t_type == dict:
            for item in decollate_batch(batch):
                PadListDataCollate.inverse(item)
def update_meta(rets: Sequence, func, args, kwargs):
    """Update the metadata from the output of `__torch_function__`. The output
    could be a single object, or a sequence of them. Hence, they get converted to a
    sequence if necessary and then processed by iterating across them.

    For each element, if not of type `MetaTensor`, then nothing to do
    """
    out = []
    metas = None
    for idx, ret in enumerate(rets):
        # if not `MetaTensor`, nothing to do.
        if not isinstance(ret, MetaTensor):
            pass
        # if not tracking, convert to `torch.Tensor`.
        elif not (get_track_meta() or get_track_transforms()):
            ret = ret.as_tensor()
        # else, handle the `MetaTensor` metadata.
        else:
            # gather the metadata of all MetaObj inputs to `func`
            meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
            ret._copy_meta(meta_args)
            # If we have a batch of data, then we need to be careful if a slice of
            # the data is returned. Depending on how the data are indexed, we return
            # some or all of the metadata, and the return object may or may not be a
            # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
            if ret.is_batch:
                # only decollate metadata once
                if metas is None:
                    metas = decollate_batch(ret.meta)
                # if indexing e.g., `batch[0]`
                if func == torch.Tensor.__getitem__:
                    idx = args[1]
                    if isinstance(idx, Sequence):
                        idx = idx[0]
                    # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                    # first element will be `slice(None, None, None)` and `Ellipsis`,
                    # respectively. Don't need to do anything with the metadata.
                    if idx not in (slice(None, None, None), Ellipsis):
                        meta = metas[idx]
                        # if using e.g., `batch[0:2]`, then `is_batch` should still be
                        # `True`. Also re-collate the remaining elements.
                        if isinstance(meta, list) and len(meta) > 1:
                            ret.meta = list_data_collate(meta)
                        # if using e.g., `batch[0]` or `batch[0, 1]`, then return single
                        # element from batch, and set `is_batch` to `False`.
                        else:
                            ret.meta = meta
                            ret.is_batch = False
                # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                # But we only want to split the batch if the `unbind` is along the 0th
                # dimension.
                elif func == torch.Tensor.unbind:
                    if len(args) > 1:
                        dim = args[1]
                    elif "dim" in kwargs:
                        dim = kwargs["dim"]
                    else:
                        dim = 0
                    if dim == 0:
                        ret.meta = metas[idx]
                        ret.is_batch = False
            # keep the affine on the same device as the tensor data
            # NOTE(review): nesting reconstructed from a single-line source --
            # confirm this runs for every tracked MetaTensor, not only for batches
            ret.affine = ret.affine.to(ret.device)
        out.append(ret)
    # if the input was a tuple, then return it as a tuple
    return tuple(out) if isinstance(rets, tuple) else out
def test_decollation_examples(self, input_val, expected_out):
    """Decollating `input_val` must produce exactly `expected_out`."""
    self.assertListEqual(expected_out, decollate_batch(input_val))
def __call__(self, data: dict):
    """Split a collated batch dict into a list of per-sample dictionaries."""
    result = decollate_batch(data, detach=self.detach)
    return result
def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
    """
    Update the metadata from the output of `MetaTensor.__torch_function__`.

    The output of `torch.Tensor.__torch_function__` could be a single object or a
    sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
    list of not already, and then we loop across each element, processing metadata as
    necessary. For each element, if not of type `MetaTensor`, then nothing to do.

    Args:
        rets: the output from `torch.Tensor.__torch_function__`, which has been
            converted to a list in `MetaTensor.__torch_function__` if it wasn't
            already a `Sequence`.
        func: the torch function that was applied. Examples might be `torch.squeeze`
            or `torch.Tensor.__add__`. We need this since the metadata need to be
            treated differently if a batch of data is considered. For example,
            slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
            dimension of a batch of data should return a ith tensor with the ith
            metadata.
        args: positional arguments that were passed to `func`.
        kwargs: keyword arguments that were passed to `func`.

    Returns:
        A sequence with the same number of elements as `rets`. For each element, if
        the input type was not `MetaTensor`, then no modifications will have been
        made. If global parameters have been set to false (e.g.,
        `not get_track_meta()`), then any `MetaTensor` will be converted to
        `torch.Tensor`. Else, metadata will be propagated as necessary (see
        :py:func:`MetaTensor._copy_meta`).
    """
    out = []
    metas = None
    # the result is considered a batch if any MetaObj input was a batch
    is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
    for idx, ret in enumerate(rets):
        # if not `MetaTensor`, nothing to do.
        if not isinstance(ret, MetaTensor):
            pass
        # if not tracking, convert to `torch.Tensor`.
        elif not get_track_meta():
            ret = ret.as_tensor()
        # else, handle the `MetaTensor` metadata.
        else:
            meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())
            ret.is_batch = is_batch
            # per-instance attributes are only copied for non-batch data
            ret.copy_meta_from(meta_args, copy_attr=not is_batch)
            # the following is not implemented but the network arch may run into this case:
            # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
            #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
            # If we have a batch of data, then we need to be careful if a slice of
            # the data is returned. Depending on how the data are indexed, we return
            # some or all of the metadata, and the return object may or may not be a
            # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
            if is_batch:
                # if indexing e.g., `batch[0]`
                if func == torch.Tensor.__getitem__:
                    batch_idx = args[1]
                    if isinstance(batch_idx, Sequence):
                        batch_idx = batch_idx[0]
                    # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                    # first element will be `slice(None, None, None)` and `Ellipsis`,
                    # respectively. Don't need to do anything with the metadata.
                    if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
                        # decollate the input batch and pick the sliced element's metadata
                        ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
                        if isinstance(ret_meta, list):
                            # e.g. batch[0:2], re-collate
                            ret_meta = list_data_collate(ret_meta)
                        else:
                            # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
                            ret_meta.is_batch = False
                        ret.__dict__ = ret_meta.__dict__.copy()
                # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                # But we only want to split the batch if the `unbind` is along the 0th
                # dimension.
                elif func == torch.Tensor.unbind:
                    if len(args) > 1:
                        dim = args[1]
                    elif "dim" in kwargs:
                        dim = kwargs["dim"]
                    else:
                        dim = 0
                    if dim == 0:
                        # only decollate the input metadata once, then reuse per element
                        if metas is None:
                            metas = decollate_batch(args[0], detach=False)
                        ret.__dict__ = metas[idx].__dict__.copy()
                        ret.is_batch = False
        out.append(ret)
    # if the input was a tuple, then return it as a tuple
    return tuple(out) if isinstance(rets, tuple) else out
def __call__(self, data: dict) -> List[dict]:
    """Decollate a batch dictionary into a list of per-sample dictionaries.

    NOTE(review): `self.batch_size` is passed as the second *positional*
    argument of `decollate_batch`. Confirm this matches the installed
    signature -- older versions accept `batch_size` in that position, while
    newer ones take `detach` there, in which case the batch size would be
    silently interpreted as a truthy `detach` flag.
    """
    return decollate_batch(data, self.batch_size)