def test_inverse_inferred_seg(self):
    """End-to-end check that transforms applied to labels can be inverted on
    model output, both per-sample (via ``decollate_batch``) and per-batch
    (via ``BatchInverseTransform``)."""
    # build 20 synthetic 2d image/label pairs
    test_data = []
    for _ in range(20):
        image, label = create_test_image_2d(100, 101)
        test_data.append({"image": image, "label": label.astype(np.float32)})

    batch_size = 10
    # num workers = 0 for mac
    num_workers = 2 if sys.platform != "darwin" else 0
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))])
    # count how many of the composed transforms record invertible metadata
    num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # tiny 2d UNet — only needs to produce output of the same spatial size
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(2, 4),
        strides=(2,),
    ).to(device)

    data = first(loader)
    labels = data["label"].to(device)
    segs = model(labels).detach().cpu()
    # pair the predictions with the recorded transform trace so they can be inverted
    label_transform_key = "label" + InverseKeys.KEY_SUFFIX
    segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

    segs_dict_decollated = decollate_batch(segs_dict)
    # inverse of individual segmentation
    seg_dict = first(segs_dict_decollated)
    # test to convert interpolation mode for 1 data of model output batch
    convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)

    # "label" is the only key present, so missing-keys mode must be enabled
    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse(seg_dict)["label"]
    self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
    self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
    # inverting should restore the original (pre-transform) spatial shape
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

    # Inverse of batch
    batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation)
    with allow_missing_keys_mode(transforms):
        inv_batch = batch_inverter(segs_dict)
    self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
def __call__(self, engine: Engine) -> None:
    """
    Invert the recorded transforms on selected model outputs and store the
    inverted results back into ``engine.state.output``.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    for output_key, batch_key, nearest_interp in zip(self.output_keys, self.batch_keys, self.nearest_interp):
        transform_key = batch_key + InverseKeys.KEY_SUFFIX
        if transform_key not in engine.state.batch:
            # no transform trace was recorded for this key — nothing to invert
            warnings.warn(f"all the transforms on `{batch_key}` are not InvertibleTransform.")
            continue

        transform_info = engine.state.batch[transform_key]
        if nearest_interp:
            # deepcopy so the recorded trace in state.batch is not mutated
            transform_info = convert_inverse_interp_mode(
                trans_info=deepcopy(transform_info),
                mode="nearest",
                align_corners=None,
            )

        # pair model output with the transform trace under the batch key
        segs_dict = {
            batch_key: engine.state.output[output_key].detach().cpu(),
            transform_key: transform_info,
        }
        # carry the meta dict along if present so inversion can use it
        meta_dict_key = f"{batch_key}_{self.meta_key_postfix}"
        if meta_dict_key in engine.state.batch:
            segs_dict[meta_dict_key] = engine.state.batch[meta_dict_key]

        # only batch_key is present in segs_dict, so allow missing keys during inversion
        with allow_missing_keys_mode(self.transform):  # type: ignore
            inverted_key = f"{output_key}_{self.postfix}"
            engine.state.output[inverted_key] = [self._totensor(i[batch_key]) for i in self.inverter(segs_dict)]
def test_inverse_inferred_seg(self, extra_transform):
    """Check that MetaTensor-tracked transforms can be inverted on model
    output, after manually undoing the final (crop/sample) transform.

    Args:
        extra_transform: last transform of the pipeline; assumed to be a
            cropping/sampling transform whose inverse is applied by hand
            via ``ResizeWithPadOrCrop``.
    """
    # build 20 synthetic 2d image/label pairs
    test_data = []
    for _ in range(20):
        image, label = create_test_image_2d(100, 101)
        test_data.append({"image": image, "label": label.astype(np.float32)})

    batch_size = 10
    # num workers = 0 for mac
    num_workers = 2 if sys.platform == "linux" else 0
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # tiny 2d UNet — only needs to produce output of the same spatial size
    model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1, )).to(device)

    data = first(loader)
    # extra_transform presumably produces NUM_SAMPLES crops per input — TODO confirm
    self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

    labels = data["label"].to(device)
    self.assertIsInstance(labels, MetaTensor)
    segs = model(labels).detach().cpu()
    segs_decollated = decollate_batch(segs)
    # metadata (applied_operations) must survive the model + decollate round trip
    self.assertIsInstance(segs_decollated[0], MetaTensor)
    # inverse of individual segmentation
    seg_metatensor = first(segs_decollated)
    # test to convert interpolation mode for 1 data of model output batch
    convert_applied_interp_mode(seg_metatensor.applied_operations, mode="nearest", align_corners=None)

    # manually invert the last crop samples
    xform = seg_metatensor.applied_operations.pop(-1)
    shape_before_extra_xform = xform["orig_size"]
    resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform)
    # disable tracing so this manual inversion is not itself recorded
    with resizer.trace_transform(False):
        seg_metatensor = resizer(seg_metatensor)

    # "label" is the only key supplied, so missing-keys mode must be enabled
    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse({"label": seg_metatensor})["label"]
    # inverting should restore the original (pre-transform) spatial shape
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
def __call__(self, engine: Engine) -> None:
    """
    Invert the recorded transforms on selected model outputs, optionally
    convert/move the results, and store them back into ``engine.state.output``
    (and the inverted meta dicts into ``engine.state.batch``).

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    for output_key, batch_key, nearest_interp, to_tensor, device, post_func in zip(
        self.output_keys, self.batch_keys, self.nearest_interp, self.to_tensor, self.device, self.post_func
    ):
        transform_key = batch_key + InverseKeys.KEY_SUFFIX
        if transform_key not in engine.state.batch:
            # no transform trace was recorded for this key — nothing to invert
            warnings.warn(f"all the transforms on `{batch_key}` are not InvertibleTransform.")
            continue

        transform_info = engine.state.batch[transform_key]
        if nearest_interp:
            # deepcopy so the recorded trace in state.batch is not mutated
            transform_info = convert_inverse_interp_mode(
                trans_info=deepcopy(transform_info),
                mode="nearest",
                align_corners=None,
            )

        output = engine.state.output[output_key]
        if isinstance(output, torch.Tensor):
            # detach from the autograd graph before inversion
            output = output.detach()
        # pair model output with the transform trace under the batch key
        segs_dict = {
            batch_key: output,
            transform_key: transform_info,
        }
        # carry the meta dict along if present so inversion can use it
        meta_dict_key = f"{batch_key}_{self.meta_key_postfix}"
        if meta_dict_key in engine.state.batch:
            segs_dict[meta_dict_key] = engine.state.batch[meta_dict_key]

        # only batch_key is present in segs_dict, so allow missing keys during inversion
        with allow_missing_keys_mode(self.transform):  # type: ignore
            inverted = self.inverter(segs_dict)

        # save the inverted data into state.output
        inverted_key = f"{output_key}_{self.postfix}"
        engine.state.output[inverted_key] = [
            post_func(self._totensor(i[batch_key]).to(device) if to_tensor else i[batch_key]) for i in inverted
        ]

        # save the inverted meta dict into state.batch
        if meta_dict_key in engine.state.batch:
            engine.state.batch[f"{inverted_key}_{self.meta_key_postfix}"] = [
                i.get(meta_dict_key) for i in inverted
            ]
def test_multiple(self):
    """``allow_missing_keys_mode`` over a list of transforms forces every one
    to allow missing keys inside the context and restores each afterwards."""
    initial_flags = [True, False]
    pads = [SpatialPadd(["image", "label"], 10, allow_missing_keys=flag) for flag in initial_flags]
    with allow_missing_keys_mode(pads):
        for pad in pads:
            # inside the context, every transform tolerates missing keys
            self.assertTrue(pad.allow_missing_keys)
            # and applying the transform with a missing key succeeds
            _ = pad(self.data)
    # after exiting, each transform is back to its original setting
    for pad, flag in zip(pads, initial_flags):
        self.assertEqual(pad.allow_missing_keys, flag)
def test_map_transform(self):
    """``allow_missing_keys_mode`` on a single MapTransform: the flag is True
    inside the context and restored on exit; without it, a missing key raises."""
    for initial_flag in (True, False):
        pad = SpatialPadd(["image", "label"], 10, allow_missing_keys=initial_flag)
        with allow_missing_keys_mode(pad):
            # check state is True
            self.assertTrue(pad.allow_missing_keys)
            # and that transform works even though key is missing
            _ = pad(self.data)
        # check it has returned to original state
        self.assertEqual(pad.allow_missing_keys, initial_flag)
        if not initial_flag:
            # should fail because amks==False and key is missing
            with self.assertRaises(KeyError):
                _ = pad(self.data)
def test_compose(self):
    """``allow_missing_keys_mode`` on a ``Compose`` recursively flips every
    contained transform's flag and restores each afterwards."""
    flags = [True, False, True]
    pipeline = Compose(
        [SpatialPadd(["image", "label"], 10, allow_missing_keys=flag) for flag in flags]
    )
    with allow_missing_keys_mode(pipeline):
        # check states are all True
        for inner in pipeline.transforms:
            self.assertTrue(inner.allow_missing_keys)
        # and that transform works even though key is missing
        _ = pipeline(self.data)
    # check they've returned to original state
    for inner, flag in zip(pipeline.transforms, flags):
        self.assertEqual(inner.allow_missing_keys, flag)
    # should fail because not all amks==True and key is missing
    with self.assertRaises((KeyError, RuntimeError)):
        _ = pipeline(self.data)
def test_array_transform(self):
    """Non-dictionary (array) transforms are rejected: the context manager
    only supports MapTransforms, so it raises TypeError."""
    for xform in (SpatialPad(10), Compose([SpatialPad(10)])):
        with self.assertRaises(TypeError), allow_missing_keys_mode(xform):
            pass