def test_type_shape(self, input_data, expected_type, expected_shape):
    """Collate ``input_data`` and verify the container type and the array shape.

    The collated result is inspected as a dict holding the array under
    ``"image"`` when it is a dict, otherwise as a sequence whose first item
    is the array.
    """
    collated = list_data_collate(input_data)
    self.assertIsInstance(collated, expected_type)
    # dict results expose the array by key, anything else by position
    array = collated["image"] if isinstance(collated, dict) else collated[0]
    self.assertEqual(array.shape, expected_shape)
def __call__(self, batch: Any):
    """
    Pad the array items of a batch to a common size, then collate the batch.

    For every key (list of dicts) or position (list of lists), the shapes of
    the items are compared with their first dimension excluded. If they
    differ, each item is padded with `SpatialPad` up to the per-dimension
    maximum. For dictionary data the applied padding is recorded through
    `self.push_transform` together with the original size.

    Args:
        batch: batch of data to pad-collate, either a list of dicts or a
            list of lists/tuples.

    Returns:
        The output of `list_data_collate` applied to the (possibly padded) batch.
    """
    # data is either list of dicts or list of lists
    is_list_of_dicts = isinstance(batch[0], dict)
    # Snapshot the keys before looping: the elements of `batch` are modified
    # inside the loop (`replace_element` / `push_transform`), and iterating a
    # live `dict.keys()` view raises `RuntimeError` as soon as a key is added.
    keys_or_indices = tuple(batch[0].keys()) if is_list_of_dicts else range(len(batch[0]))

    # loop over items inside of each element in a batch
    for key_or_idx in keys_or_indices:
        # collect the shape (without the first dimension) of every array item;
        # collection stops at the first item that is not an array
        max_shapes = []
        for elem in batch:
            if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                break
            max_shapes.append(elem[key_or_idx].shape[1:])
        # len > 0 if objects were arrays, else skip as no padding to be done
        if not max_shapes:
            continue
        shapes = np.array(max_shapes)
        max_shape = shapes.max(axis=0)
        # If all same size, skip
        if np.all(shapes.min(axis=0) == max_shape):
            continue
        # Do we need to convert output to Tensor?
        output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

        # Use `SpatialPad` to match sizes; the padding behaviour is controlled
        # by `self.method`, `self.mode` and `self.np_kwargs`.
        padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs)
        transform = padder if not output_to_tensor else Compose([padder, ToTensor()])

        for idx, batch_i in enumerate(batch):
            im = batch_i[key_or_idx]
            orig_size = im.shape[1:]
            padded = transform(im)
            batch = replace_element(padded, batch, idx, key_or_idx)

            # If we have a dictionary of data, record the applied padding
            if is_list_of_dicts:
                self.push_transform(batch[idx], key_or_idx, orig_size=orig_size)

    # After padding, use default list collator
    return list_data_collate(batch)
def test_collate(self, device, dtype):
    """Collate several images and verify type and shape of the result and of its affine."""
    numel = 3
    ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)]
    collated = list_data_collate(ims)

    # tensor: collating prepends a dimension of length `numel`
    self.assertIsInstance(collated, MetaTensor)
    self.assertTupleEqual(tuple(collated.shape), (numel, *ims[0].shape))
    # NOTE(review): each image is compared with itself here (`im` is `ims[pos]`),
    # so this check cannot fail. Presumably it was meant to compare against the
    # corresponding element of `collated` -- confirm and adjust.
    for pos, im in enumerate(ims):
        self.check(im, ims[pos], ids=True)

    # affine: also gains the leading dimension
    affine = collated.affine
    self.assertIsInstance(affine, torch.Tensor)
    self.assertTupleEqual(tuple(affine.shape), (numel, *ims[0].affine.shape))
def test_default_device(self, data_type):
    """Sliding-window inference without `sw_device`: output stays on the input's device type.

    The predictor adds one to an all-ones input, so every output value must be 2.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu:0"
    image = data_type(torch.ones((3, 16, 15, 7))).to(device=device)
    batch = list_data_collate([image])  # make a proper batch
    roi_shape = (4, 10, 7)
    sw_batch_size = 10

    def add_one(data):
        return data + 1

    result = sliding_window_inference(batch, roi_shape, sw_batch_size, add_one)

    np.testing.assert_string_equal(batch.device.type, result.device.type)
    expected_val = np.full((1, 3, 16, 15, 7), 2, dtype=np.float32)
    np.testing.assert_allclose(result.cpu().numpy(), expected_val)
def test_sw_device(self, data_type):
    """Sliding-window inference with `sw_device="cuda"` on a CPU input.

    The predictor asserts that it receives its data on a CUDA device, while
    the final result must have the same device type as the input.
    """
    # NOTE(review): needs a CUDA device; presumably skipped elsewhere when none is available.
    image = data_type(torch.ones((3, 16, 15, 7))).to(device="cpu")
    batch = list_data_collate([image])  # make a proper batch
    roi_shape = (4, 10, 7)
    sw_batch_size = 10

    def add_one_on_cuda(data):
        self.assertEqual(data.device.type, "cuda")
        return data + torch.tensor(1, device="cuda")

    result = sliding_window_inference(batch, roi_shape, sw_batch_size, add_one_on_cuda, sw_device="cuda")

    np.testing.assert_string_equal(batch.device.type, result.device.type)
    expected_val = np.full((1, 3, 16, 15, 7), 2, dtype=np.float32)
    np.testing.assert_allclose(result.cpu().numpy(), expected_val)
def __call__(self, batch: Any):
    """
    Pad the array items of a batch up to the largest size found, then collate.

    Args:
        batch: batch of data to pad-collate

    """
    # data is either list of dicts or list of lists
    first_elem = batch[0]
    is_list_of_dicts = isinstance(first_elem, dict)
    # keys are copied into a tuple up front, so later changes to the
    # elements cannot disturb the iteration
    if is_list_of_dicts:
        batch_item = tuple(first_elem.keys())
    else:
        batch_item = range(len(first_elem))

    for key_or_idx in batch_item:
        # gather shapes (first dimension excluded) until a non-array item is met
        shapes = []
        for elem in batch:
            item = elem[key_or_idx]
            if not isinstance(item, (torch.Tensor, np.ndarray)):
                break
            shapes.append(item.shape[1:])

        # no arrays collected: nothing to pad for this key/index
        if not shapes:
            continue

        shape_table = np.array(shapes)
        max_shape = shape_table.max(axis=0)
        # every item already has the maximum size: nothing to pad
        if np.all(shape_table.min(axis=0) == max_shape):
            continue

        # `SpatialPad` brings every item to `max_shape`
        padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.kwargs)
        for idx, batch_i in enumerate(batch):
            item = batch_i[key_or_idx]
            orig_size = item.shape[1:]
            batch = replace_element(padder(item), batch, idx, key_or_idx)

            # for dictionary data, record the padding with the original size
            if is_list_of_dicts:
                self.push_transform(batch[idx], key_or_idx, orig_size=orig_size)

    # After padding, use default list collator
    return list_data_collate(batch)
def update_meta(rets: Sequence, func, args, kwargs):
    """Update the metadata from the output of `__torch_function__`.

    The output could be a single object, or a sequence of them. Hence, they get
    converted to a sequence if necessary and then processed by iterating across
    them. For each element, if not of type `MetaTensor`, then nothing to do.

    Args:
        rets: sequence of outputs to process.
        func: the torch function that produced `rets`. `torch.Tensor.__getitem__`
            and `torch.Tensor.unbind` get special treatment for batched data.
        args: positional arguments that were passed to `func`.
        kwargs: keyword arguments that were passed to `func`.

    Returns:
        The processed elements, as a tuple if `rets` was a tuple, else as a list.
    """
    out = []
    metas = None
    for idx, ret in enumerate(rets):
        # if not `MetaTensor`, nothing to do.
        if not isinstance(ret, MetaTensor):
            pass
        # if not tracking, convert to `torch.Tensor`.
        elif not (get_track_meta() or get_track_transforms()):
            ret = ret.as_tensor()
        # else, handle the `MetaTensor` metadata.
        else:
            meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
            ret._copy_meta(meta_args)

            # If we have a batch of data, then we need to be careful if a slice of
            # the data is returned. Depending on how the data are indexed, we return
            # some or all of the metadata, and the return object may or may not be a
            # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
            if ret.is_batch:
                # only decollate metadata once
                if metas is None:
                    metas = decollate_batch(ret.meta)
                # if indexing e.g., `batch[0]`
                if func == torch.Tensor.__getitem__:
                    # a separate name keeps the loop counter `idx` intact
                    batch_idx = args[1]
                    if isinstance(batch_idx, Sequence):
                        batch_idx = batch_idx[0]
                    # if using e.g., `batch[:, -1]`, `batch[..., -1]` or `batch[None]`,
                    # then the first element will be `slice(None, None, None)`,
                    # `Ellipsis` or `None`, respectively. None of these select
                    # along the batch dimension (and `None` is not a valid list
                    # index), so the metadata is left untouched.
                    if batch_idx not in (slice(None, None, None), Ellipsis, None):
                        meta = metas[batch_idx]
                        # if using e.g., `batch[0:2]` or `batch[0:1]`, a list of
                        # metadata is selected: `is_batch` should still be `True`
                        # and the selected elements are re-collated.
                        if isinstance(meta, list) and meta:
                            ret.meta = list_data_collate(meta)
                        # if using e.g., `batch[0]` or `batch[0, 1]`, then return single
                        # element from batch, and set `is_batch` to `False`.
                        # (An empty selection also ends up here, as before.)
                        else:
                            ret.meta = meta
                            ret.is_batch = False
                # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                # But we only want to split the batch if the `unbind` is along the 0th
                # dimension.
                elif func == torch.Tensor.unbind:
                    if len(args) > 1:
                        dim = args[1]
                    elif "dim" in kwargs:
                        dim = kwargs["dim"]
                    else:
                        dim = 0
                    if dim == 0:
                        ret.meta = metas[idx]
                        ret.is_batch = False

            # keep the affine on the same device as the returned tensor
            ret.affine = ret.affine.to(ret.device)
        out.append(ret)
    # if the input was a tuple, then return it as a tuple
    return tuple(out) if isinstance(rets, tuple) else out
def test_indexing(self):
    """
    Check the metadata is returned in the expected format depending on whether
    the input `MetaTensor` is a batch of data or not.
    """
    ims = [self.get_im()[0] for _ in range(5)]
    data = list_data_collate(ims)

    def assert_split_into_items(parts, compare_ops):
        # split along the batch dim: one part per image, each with its own metadata
        self.assertIsInstance(parts, tuple)
        self.assertEqual(len(parts), len(ims))
        for part, im in zip(parts, ims):
            self.check(part, im, ids=False)
            if compare_ops:
                self.assertEqual(part.applied_operations, im.applied_operations)

    def assert_split_keeps_batch_meta(parts):
        # split along a non-batch dim: every part keeps the metadata of the whole batch
        self.assertIsInstance(parts, tuple)
        self.assertEqual(len(parts), ims[0].shape[-1])
        for part in parts:
            self.check_meta(part, data)

    # non-batch data: metadata is copied wholly when indexing or iterating
    single = ims[0]
    self.check_meta(single[0], single)
    self.check_meta(next(iter(single)), single)

    # indexing with `None` only adds a leading dimension
    self.assertEqual(single[None].shape, (1, 1, 10, 8))
    self.assertEqual(data[None].shape, (1, 5, 1, 10, 8))

    # index
    self.check(data[0], ims[0], ids=False)

    # iter
    self.check(next(iter(data)), ims[0], ids=False)

    # complex indexing

    # `is_batch==True`, should have subset of image and metadata.
    self.check(data[1:3], list_data_collate(ims[1:3]), ids=False)

    # is_batch==True, should have subset of image and same metadata as `[1:3]`.
    self.check(data[1:3, 0], list_data_collate([im[0] for im in ims[1:3]]), ids=False)

    # `is_batch==False`, should have first metadata and subset of first image.
    first_of_first = data[0, 0]
    self.check(first_of_first, ims[0][0], ids=False)
    self.assertEqual(first_of_first.applied_operations, ims[0][0].applied_operations)

    # `is_batch==True`, should have all metadata and subset of all images.
    self.check(data[:, 0], list_data_collate([im[0] for im in ims]), ids=False)

    # `is_batch==True`, should have all metadata and subset of all images.
    self.check(data[..., -1], list_data_collate([im[..., -1] for im in ims]), ids=False)

    # `is_batch==False`, tuple split along batch dim (positional `dim`).
    assert_split_into_items(data.unbind(0), compare_ops=False)

    # `is_batch==False`, tuple split along batch dim (keyword `dim`);
    # the applied operations are compared here as well.
    assert_split_into_items(data.unbind(dim=0), compare_ops=True)

    # `is_batch==True`, tuple split along non-batch dim (positional `dim`).
    assert_split_keeps_batch_meta(data.unbind(-1))

    # `is_batch==True`, tuple split along non-batch dim (keyword `dim`).
    assert_split_keeps_batch_meta(data.unbind(dim=-1))
def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
    """
    Update the metadata from the output of `MetaTensor.__torch_function__`.

    The output of `torch.Tensor.__torch_function__` could be a single object or a
    sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
    list if not already, and then we loop across each element, processing metadata
    as necessary. For each element, if not of type `MetaTensor`, then nothing to do.

    Args:
        rets: the output from `torch.Tensor.__torch_function__`, which has been
            converted to a list in `MetaTensor.__torch_function__` if it wasn't
            already a `Sequence`.
        func: the torch function that was applied. Examples might be `torch.squeeze`
            or `torch.Tensor.__add__`. We need this since the metadata need to be
            treated differently if a batch of data is considered. For example,
            slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
            dimension of a batch of data should return the ith tensor with the ith
            metadata.
        args: positional arguments that were passed to `func`.
        kwargs: keyword arguments that were passed to `func`.

    Returns:
        A sequence with the same number of elements as `rets`. For each element, if
        the input type was not `MetaTensor`, then no modifications will have been
        made. If global parameters have been set to false (e.g.,
        `not get_track_meta()`), then any `MetaTensor` will be converted to
        `torch.Tensor`. Else, metadata will be propagated as necessary (see
        :py:func:`MetaTensor._copy_meta`).
    """
    out = []
    metas = None
    # the result is treated as a batch if any of the inputs is flagged as a batch
    is_batch = any(
        x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch")
    )
    for idx, ret in enumerate(rets):
        # if not `MetaTensor`, nothing to do.
        if not isinstance(ret, MetaTensor):
            out.append(ret)
            continue
        # if not tracking, convert to `torch.Tensor`.
        if not get_track_meta():
            out.append(ret.as_tensor())
            continue

        # else, handle the `MetaTensor` metadata.
        meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())
        ret.is_batch = is_batch
        ret.copy_meta_from(meta_args, copy_attr=not is_batch)
        # the following is not implemented but the network arch may run into this case:
        # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
        #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")

        # If we have a batch of data, then we need to be careful if a slice of
        # the data is returned. Depending on how the data are indexed, we return
        # some or all of the metadata, and the return object may or may not be a
        # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
        if is_batch:
            if func == torch.Tensor.__getitem__:
                # if indexing e.g., `batch[0]`; only the index into the first
                # (batch) dimension matters here
                batch_idx = args[1]
                if isinstance(batch_idx, Sequence):
                    batch_idx = batch_idx[0]
                # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                # first element will be `slice(None, None, None)` and `Ellipsis`,
                # respectively. Don't need to do anything with the metadata.
                if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
                    picked = decollate_batch(args[0], detach=False)[batch_idx]
                    if isinstance(picked, list):
                        # e.g. batch[0:2], re-collate
                        picked = list_data_collate(picked)
                    else:
                        # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
                        picked.is_batch = False
                    ret.__dict__ = picked.__dict__.copy()
            elif func == torch.Tensor.unbind:
                # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                # But we only want to split the batch if the `unbind` is along the 0th
                # dimension.
                dim = args[1] if len(args) > 1 else kwargs.get("dim", 0)
                if dim == 0:
                    # decollate once and reuse for the remaining outputs
                    if metas is None:
                        metas = decollate_batch(args[0], detach=False)
                    ret.__dict__ = metas[idx].__dict__.copy()
                    ret.is_batch = False

        out.append(ret)

    # if the input was a tuple, then return it as a tuple
    if isinstance(rets, tuple):
        return tuple(out)
    return out