def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff):
    for key in keys:
        orig = orig_d[key]
        fwd_bck = fwd_bck_d[key]
        if isinstance(fwd_bck, torch.Tensor):
            fwd_bck = fwd_bck.cpu().numpy()
        unmodified = unmodified_d[key]
        if isinstance(orig, np.ndarray):
            mean_diff = np.mean(np.abs(orig - fwd_bck))
            resized = ResizeWithPadOrCrop(orig.shape[1:])(unmodified)
            if isinstance(resized, torch.Tensor):
                resized = resized.detach().cpu().numpy()
            unmodded_diff = np.mean(np.abs(orig - resized))
            try:
                self.assertLessEqual(mean_diff, acceptable_diff)
            except AssertionError:
                print(
                    f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), "
                    f"unmodified diff: {unmodded_diff}"
                )
                if orig[0].ndim == 1:
                    print("orig", orig[0])
                    print("fwd_bck", fwd_bck[0])
                    print("unmod", unmodified[0])
                raise
def test_inverse_inferred_seg(self, extra_transform):
    test_data = []
    for _ in range(20):
        image, label = create_test_image_2d(100, 101)
        test_data.append({"image": image, "label": label.astype(np.float32)})

    batch_size = 10
    # num workers = 0 for mac
    num_workers = 2 if sys.platform == "linux" else 0
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1,)).to(device)

    data = first(loader)
    self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

    labels = data["label"].to(device)
    self.assertIsInstance(labels, MetaTensor)
    segs = model(labels).detach().cpu()
    segs_decollated = decollate_batch(segs)
    self.assertIsInstance(segs_decollated[0], MetaTensor)
    # inverse of individual segmentation
    seg_metatensor = first(segs_decollated)
    # test to convert interpolation mode for 1 data of model output batch
    convert_applied_interp_mode(seg_metatensor.applied_operations, mode="nearest", align_corners=None)

    # manually invert the last crop samples
    xform = seg_metatensor.applied_operations.pop(-1)
    shape_before_extra_xform = xform["orig_size"]
    resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform)
    with resizer.trace_transform(False):
        seg_metatensor = resizer(seg_metatensor)

    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse({"label": seg_metatensor})["label"]
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
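# A minimal sketch of the manual-undo idiom used above (same MONAI API, made-up
# shapes): pop the last recorded op from `applied_operations`, undo it by hand, and
# wrap the undo in `trace_transform(False)` so the undo itself is not recorded.
import numpy as np
from monai.transforms import ResizeWithPadOrCrop

x = ResizeWithPadOrCrop(spatial_size=(64, 64))(np.zeros((1, 50, 70), dtype=np.float32))
op = x.applied_operations.pop(-1)  # the single pad/crop record
undo = ResizeWithPadOrCrop(spatial_size=op["orig_size"])
with undo.trace_transform(False):  # do not record the undo
    x = undo(x)
print(tuple(x.shape), len(x.applied_operations))  # (1, 50, 70) 0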
def __call__(self, image, label, params=None):
    if params is None:
        params = self._random()
    self.params = params
    if self.params["width"] > 0 and self.params["height"] > 0:
        image = ResizeWithPadOrCrop([self.params["width"], self.params["height"]])(image)
        label = ResizeWithPadOrCrop([self.params["width"], self.params["height"]])(label)
    return image, label
def __call__(
    self,
    img: np.ndarray,
    msk: Optional[np.ndarray] = None,
    center: Optional[tuple] = None,
    z_axis: Optional[int] = None,
):
    """
    Apply the transform to `img`, assuming `img` is channel-first and
    slicing doesn't apply to the channel dim.
    """
    if self.mask_data is None and msk is None:
        raise ValueError("Unknown mask_data.")
    mask_data_ = np.array([[1]])
    if self.mask_data is not None and msk is None:
        mask_data_ = self.mask_data > 0
    if msk is not None:
        mask_data_ = msk > 0
    mask_data_ = np.asarray(mask_data_)
    if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]:
        raise ValueError(
            "When mask_data is not single channel, mask_data channels must match img, "
            f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}."
        )
    z_axis_ = z_axis if z_axis is not None else self.z_axis
    if center is None:
        center = self.get_center_pos(mask_data_, z_axis_)
    if self.crop_mode in ["single", "parallel"]:
        size_ = self.get_new_spatial_size(z_axis_)
        size_ = list(map(int, size_))
        slice_ = SpatialCrop(roi_center=center, roi_size=size_)(img)
        if np.any(slice_.shape[1:] != size_):
            slice_ = ResizeWithPadOrCrop(spatial_size=size_)(slice_)
        return np.moveaxis(slice_.squeeze(0), z_axis_, 0)
    else:
        cross_slices = np.zeros(shape=(3,) + self.roi_size)
        for k in range(3):
            size_ = np.insert(self.roi_size, k, 1)
            slice_ = SpatialCrop(roi_center=center, roi_size=size_)(img)
            if np.any(slice_.shape[1:] != size_):
                slice_ = ResizeWithPadOrCrop(spatial_size=size_)(slice_)
            cross_slices[k] = slice_.squeeze()
        return cross_slices
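# A tiny illustration of the cross-slice sizes built in the "else" branch above
# (roi_size here is hypothetical): inserting a 1 at axis k of a 2D roi_size yields
# three orthogonal single-voxel-thick slabs.
import numpy as np

roi_size = (24, 24)  # hypothetical in-plane size
for k in range(3):
    print(k, np.insert(roi_size, k, 1))  # [ 1 24 24], [24  1 24], [24 24  1]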
def __call__(self, data):
    d = dict(data)
    current_shape = d[self.ref_image].shape[1:]
    dims = len(current_shape)
    croppad = ResizeWithPadOrCrop(spatial_size=self.spatial_size)

    # get guidance following pad or crop to spatial_size
    new_guidance = []
    for guidance in d[self.guidance]:
        if guidance:
            signal = np.zeros(current_shape)
            for point in guidance:
                if dims == 2:
                    signal[point[0], point[1]] = 1.0
                else:
                    signal[point[0], point[1], point[2]] = 1.0
            signal = croppad(signal[np.newaxis, :]).squeeze(0)  # croppad requires channel dim
            new_guidance.append(np.argwhere(signal == 1.0).astype(int).tolist())
        else:
            new_guidance.append([])
    d[self.guidance] = new_guidance
    return d
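# A minimal standalone sketch of the rasterize-then-recover trick above (real
# ResizeWithPadOrCrop API, made-up shapes; assumes MONAI's default symmetric
# padding): a point is burned into a one-channel signal, padded/cropped, then read
# back, so its coordinates are re-expressed in the new spatial grid.
import numpy as np
from monai.transforms import ResizeWithPadOrCrop

signal = np.zeros((30, 40), dtype=np.float32)
signal[10, 20] = 1.0  # a single guidance point
croppad = ResizeWithPadOrCrop(spatial_size=(50, 50))
resized = croppad(signal[np.newaxis])  # the transform expects a channel dim
new_point = np.argwhere(np.asarray(resized[0]) == 1.0).astype(int).tolist()
print(new_point)  # [[20, 25]]: 10 rows and 5 cols of padding added on each side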
def test_pad_shape(self, input_param, input_shape, expected_shape):
    for p in TEST_NDARRAYS_ALL:
        if isinstance(p(0), torch.Tensor) and (
            "constant_values" in input_param or input_param["mode"] == "reflect"
        ):
            continue
        padcropper = ResizeWithPadOrCrop(**input_param)
        result = padcropper(p(np.zeros(input_shape)))
        np.testing.assert_allclose(result.shape, expected_shape)
        result = padcropper(p(np.zeros(input_shape)), mode="constant")
        np.testing.assert_allclose(result.shape, expected_shape)
        self.assertIsInstance(result, MetaTensor)
        self.assertEqual(len(result.applied_operations), 1)
        inv = padcropper.inverse(result)
        self.assertTupleEqual(inv.shape, input_shape)
        self.assertIsInstance(inv, MetaTensor)
        self.assertEqual(inv.applied_operations, [])
def test_pad_shape(self, input_param, input_shape, expected_shape):
    for p in TEST_NDARRAYS:
        if isinstance(p(0), torch.Tensor) and (
            "constant_values" in input_param or input_param["mode"] == "reflect"
        ):
            continue
        padcropper = ResizeWithPadOrCrop(**input_param)
        result = padcropper(p(np.zeros(input_shape)))
        np.testing.assert_allclose(result.shape, expected_shape)
        result = padcropper(p(np.zeros(input_shape)), mode="constant")
        np.testing.assert_allclose(result.shape, expected_shape)
def __call__(
    self,
    img: np.ndarray,
    label: Optional[np.ndarray] = None,
    image: Optional[np.ndarray] = None,
    fg_indices: Optional[np.ndarray] = None,
    bg_indices: Optional[np.ndarray] = None,
) -> List[np.ndarray]:
    """
    Args:
        img: input data to crop samples from based on the pos/neg ratio of `label` and `image`.
            Assumes `img` is a channel-first array.
        label: the label image that is used for finding foreground/background, if None, use `self.label`.
        image: optional image data to help select valid area, can be same as `img` or another image array.
            use ``label == 0 & image > image_threshold`` to select the negative sample (background) center,
            so the crop center will only exist on valid image area. if None, use `self.image`.
        fg_indices: foreground indices to randomly select crop centers, need to provide `fg_indices`
            and `bg_indices` together.
        bg_indices: background indices to randomly select crop centers, need to provide `fg_indices`
            and `bg_indices` together.
    """
    if label is None:
        label = self.label
    if label is None:
        raise ValueError("label should be provided.")
    if image is None:
        image = self.image
    if fg_indices is None or bg_indices is None:
        if self.fg_indices is not None and self.bg_indices is not None:
            fg_indices = self.fg_indices
            bg_indices = self.bg_indices
        else:
            fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)
    if self.target_label is not None:
        label = (label == self.target_label).astype(np.uint8)

    self.randomize(label, fg_indices, bg_indices, image)
    results: List[np.ndarray] = []
    if self.centers is not None:
        for center in self.centers:
            if np.any(np.greater(self.spatial_size, img.shape[1:])):
                # requested ROI is larger than the image: pad instead of crop
                cropper = ResizeWithPadOrCrop(spatial_size=self.spatial_size)
            else:
                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
            results.append(cropper(img))
    return results
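# A minimal sketch of why the fallback above matters (real MONAI transforms,
# made-up shapes): SpatialCrop can only take a sub-region, while ResizeWithPadOrCrop
# also pads axes that are smaller than the requested ROI.
import numpy as np
from monai.transforms import ResizeWithPadOrCrop, SpatialCrop

img = np.zeros((1, 32, 32), dtype=np.float32)
print(tuple(SpatialCrop(roi_center=(16, 16), roi_size=(16, 16))(img).shape))  # (1, 16, 16)
print(tuple(ResizeWithPadOrCrop(spatial_size=(48, 48))(img).shape))  # (1, 48, 48), padded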
def test_invert(self):
    set_determinism(seed=0)
    im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1])  # label image, discrete
    data = [im_fname for _ in range(12)]
    transform = Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            Orientation("RPS"),
            Spacing(pixdim=(1.2, 1.01, 0.9), mode="bilinear", dtype=np.float32),
            RandFlip(prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlip(prob=0.5),
            RandRotate90(prob=0, spatial_axes=(1, 2)),
            RandZoom(prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotate(prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
            RandAffine(prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCrop(100),
            CastToType(dtype=torch.uint8),
        ]
    )
    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
    dataset = Dataset(data, transform=transform)
    self.assertIsInstance(transform.inverse(dataset[0]), MetaTensor)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
    inverter = Invert(transform=transform, nearest_interp=True, device="cpu")

    for d in loader:
        d = decollate_batch(d)
        for item in d:
            orig = deepcopy(item)
            i = inverter(item)
            self.assertTupleEqual(orig.shape[1:], (100, 100, 100))
            # check the nearest interpolation mode
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))

    # check labels match
    reverted = i.detach().cpu().numpy().astype(np.int32)
    original = LoadImage(image_only=True)(data[-1])
    n_good = np.sum(np.isclose(reverted, original.numpy(), atol=1e-3))
    reverted_name = i.meta["filename_or_obj"]
    original_name = original.meta["filename_or_obj"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    self.assertTrue((reverted.size - n_good) < 300000, f"diff. {reverted.size - n_good}")
    set_determinism(seed=None)
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
    d = dict(data)
    label = d[self.label_key]
    image = d[self.image_key] if self.image_key else None
    fg_indices = d.get(self.fg_indices_key) if self.fg_indices_key is not None else None
    bg_indices = d.get(self.bg_indices_key) if self.bg_indices_key is not None else None

    if self.target_label is not None:
        label = (label == self.target_label).astype(np.uint8)

    self.randomize(label, fg_indices, bg_indices, image)
    if not isinstance(self.spatial_size, tuple):
        raise TypeError(f"Expect spatial_size to be tuple, but got {type(self.spatial_size)}")
    if self.centers is None:
        raise AssertionError

    results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)]
    for i, center in enumerate(self.centers):
        for key in self.key_iterator(d):
            img = d[key]
            if np.greater(self.spatial_size, img.shape[1:]).any():
                cropper = ResizeWithPadOrCrop(spatial_size=self.spatial_size)
            else:
                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
            results[i][key] = cropper(img)
        # fill in the extra keys with unmodified data
        for key in set(data.keys()).difference(set(self.keys)):
            results[i][key] = data[key]
        # add `patch_index` to the meta data
        for key in self.key_iterator(d):
            meta_data_key = f"{key}_{self.meta_key_postfix}"
            if meta_data_key not in results[i]:
                results[i][meta_data_key] = {}  # type: ignore
            results[i][meta_data_key][Key.PATCH_INDEX] = i

    return results
def test_pad_shape(self, input_param, input_shape, expected_shape):
    padcropper = ResizeWithPadOrCrop(**input_param)
    result = padcropper(np.zeros(input_shape))
    np.testing.assert_allclose(result.shape, expected_shape)
    result = padcropper(np.zeros(input_shape), mode="constant")
    np.testing.assert_allclose(result.shape, expected_shape)
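# A minimal usage sketch of the transform exercised by the tests above (real MONAI
# API, made-up shapes): a single call pads the axes that are too small and crops the
# axes that are too large, and the recorded op lets `inverse` round-trip the shape.
import numpy as np
from monai.transforms import ResizeWithPadOrCrop

img = np.zeros((1, 100, 101), dtype=np.float32)  # channel-first 2D image
padcropper = ResizeWithPadOrCrop(spatial_size=(150, 80))
out = padcropper(img)
print(tuple(out.shape))  # (1, 150, 80): dim 1 padded 100 -> 150, dim 2 cropped 101 -> 80
inv = padcropper.inverse(out)  # undo using the traced applied_operations
print(tuple(inv.shape))  # (1, 100, 101)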