def run_interaction(self, train, compose):
    """Run one training epoch driven by an ``Interaction`` iteration update and check the result.

    Args:
        train: forwarded unchanged as ``Interaction(train=...)``.
        compose: when truthy, the per-click transforms are wrapped in ``Compose``
            before being handed to ``Interaction``; otherwise the plain list is passed.

    Asserts that the interaction exposes six transforms, that the final batch
    carries a ``"guidance"`` entry, and that ``engine.state.best_metric`` ends at 9.
    The final value presumably comes from the ``add_one`` handlers registered on
    the inner-iteration events; TODO confirm against ``add_one``.
    """
    # Five independent samples: float32 image, default-dtype label, both shaped (1, 2, 2, 2).
    samples = [
        {
            "image": np.ones((1, 2, 2, 2), dtype=np.float32),
            "label": np.ones((1, 2, 2, 2)),
        }
        for _ in range(5)
    ]
    net = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
    loss_fn = torch.nn.L1Loss()

    image_label = ("image", "label")
    preprocessing = Compose(
        [
            FindAllValidSlicesd(label="label", sids="sids"),
            AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=image_label),
        ]
    )
    loader = torch.utils.data.DataLoader(Dataset(samples, transform=preprocessing), batch_size=5)

    click_transforms = [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=["image", "label", "pred"]),
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        ToTensord(keys=image_label),
    ]
    if compose:
        click_transforms = Compose(click_transforms)

    interaction = Interaction(transforms=click_transforms, train=train, max_interactions=5)
    self.assertEqual(len(interaction.transforms.transforms), 6, "Mismatch in expected transforms")

    # Engine wiring: the interaction object acts as the iteration update.
    trainer = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=loader,
        network=net,
        optimizer=optimizer,
        loss_function=loss_fn,
        iteration_update=interaction,
    )
    for event in (IterationEvents.INNER_ITERATION_STARTED, IterationEvents.INNER_ITERATION_COMPLETED):
        trainer.add_event_handler(event, add_one)

    trainer.run()
    self.assertIsNotNone(trainer.state.batch[0].get("guidance"), "guidance is missing")
    self.assertEqual(trainer.state.best_metric, 9)
def train_pre_transforms(self, context: Context):
    """Return the ordered list of training pre-processing transforms.

    ``image`` and ``label`` are loaded and given a channel axis. They are then
    passed through ``SpatialCropForegroundd`` (driven by ``label``, size
    ``self.roi_size``), resized to ``self.model_size`` (area for the image,
    nearest for the label), and the image is normalized with fixed constants.

    When ``self.dimension == 3`` a ``FindAllValidSlicesd`` step writing ``"sids"``
    is inserted before the initial seed point is added. The pipeline then adds
    the guidance signal, applies ``EnsureTyped`` with ``device=context.device``,
    and keeps only ``image``, ``label`` and ``guidance``.
    """
    pair = ("image", "label")
    # Loading, cropping, resizing and intensity normalization.
    preprocess: List[Any] = [
        LoadImaged(keys=pair),
        AddChanneld(keys=pair),
        SpatialCropForegroundd(keys=pair, source_key="label", spatial_size=self.roi_size),
        Resized(keys=pair, spatial_size=self.model_size, mode=("area", "nearest")),
        NormalizeIntensityd(keys="image", subtrahend=208.0, divisor=388.0),  # type: ignore
    ]
    # Valid-slice discovery is only part of the volumetric (3-D) pipeline.
    slice_search: List[Any] = [FindAllValidSlicesd(label="label", sids="sids")] if self.dimension == 3 else []
    guidance_steps: List[Any] = [
        AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=pair, device=context.device),
        SelectItemsd(keys=("image", "label", "guidance")),
    ]
    return preprocess + slice_search + guidance_steps
def train_pre_transforms(self, context: Context):
    """Return the ordered list of training pre-processing transforms for uint8 images.

    ``image`` and ``label`` are loaded as ``np.uint8``, and the image goes through
    ``FilterImaged(min_size=5)``. The image is then made channel-first while the
    label gets a new channel axis. A torchvision ``ColorJitter`` is applied to the
    image between a ``ToTensord`` and a ``ToNumpyd`` step, followed by a random
    90-degree rotation of both keys. The image is rescaled from [0, 255] to [-1, 1],
    an initial seed point and the guidance signal (3 intensity channels) are added,
    and ``EnsureTyped`` runs last.

    ``self`` and ``context`` are not used.
    """
    both = ("image", "label")
    return [
        LoadImaged(keys=both, dtype=np.uint8),
        FilterImaged(keys="image", min_size=5),
        AsChannelFirstd(keys="image"),
        AddChanneld(keys="label"),
        # The jitter is sandwiched between tensor/numpy conversions, presumably
        # because TorchVisiond expects tensor input; confirm if reordering.
        ToTensord(keys="image"),
        TorchVisiond(
            keys="image",
            name="ColorJitter",
            brightness=64.0 / 255.0,
            contrast=0.75,
            saturation=0.25,
            hue=0.04,
        ),
        ToNumpyd(keys="image"),
        RandRotate90d(keys=both, prob=0.5, spatial_axes=(0, 1)),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        AddInitialSeedPointExd(label="label", guidance="guidance"),
        AddGuidanceSignald(image="image", guidance="guidance", number_intensity_ch=3),
        EnsureTyped(keys=both),
    ]
def pre_transforms(self, data=None) -> Sequence[Callable]:
    """Return the ordered inference pre-processing transforms for ``self.dimension``-D input.

    The image is loaded, made channel-first and resampled to unit spacing in each
    of ``self.dimension`` axes, and guidance is derived from the input points.
    For 2-D (``self.dimension == 2``) a ``Fetch2DSliced`` step is inserted. The image
    is then channel-expanded, cropped around the guidance to ``self.spatial_size``,
    resized to ``self.model_size``, and the guidance is resized to match.
    Normalization, the guidance signal and ``EnsureTyped`` follow.

    Args:
        data: optional mapping; when truthy, its ``"device"`` entry (or ``None``
            if absent) is forwarded to ``EnsureTyped``.
    """
    # Resolve the target device once; a falsy/absent ``data`` means no device.
    device = data.get("device") if data else None

    chain = [
        LoadImaged(keys="image"),
        AsChannelFirstd(keys="image"),
        Spacingd(keys="image", pixdim=[1.0] * self.dimension, mode="bilinear"),
        AddGuidanceFromPointsd(ref_image="image", guidance="guidance", dimensions=self.dimension),
    ]
    if self.dimension == 2:
        chain.append(Fetch2DSliced(keys="image", guidance="guidance"))
    chain += [
        AddChanneld(keys="image"),
        SpatialCropGuidanced(keys="image", guidance="guidance", spatial_size=self.spatial_size),
        Resized(keys="image", spatial_size=self.model_size, mode="area"),
        ResizeGuidanced(guidance="guidance", ref_image="image"),
        NormalizeIntensityd(keys="image", subtrahend=208, divisor=388),  # type: ignore
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys="image", device=device),
    ]
    return chain
def pre_transforms(self, data):
    """Return the ordered 3-D inference pre-processing transforms.

    The image is loaded, made channel-first, resampled to 1.0 spacing on all three
    axes, and guidance is built from the input points (``dimensions=3``). After a
    channel is added, the image is cropped around the guidance to
    ``self.spatial_size`` and resized to ``self.model_size``. The guidance is then
    resized to match, the image is normalized, and the guidance signal is added.

    ``data`` is accepted for interface compatibility but is not used.
    """
    img, guide = "image", "guidance"
    return [
        LoadImaged(keys=img),
        AsChannelFirstd(keys=img),
        Spacingd(keys=img, pixdim=[1.0] * 3, mode="bilinear"),
        AddGuidanceFromPointsd(ref_image=img, guidance=guide, dimensions=3),
        AddChanneld(keys=img),
        SpatialCropGuidanced(keys=img, guidance=guide, spatial_size=self.spatial_size),
        Resized(keys=img, spatial_size=self.model_size, mode="area"),
        ResizeGuidanced(guidance=guide, ref_image=img),
        NormalizeIntensityd(keys=img, subtrahend=208, divisor=388),
        AddGuidanceSignald(image=img, guidance=guide),
    ]
def get_click_transforms(self, context: Context):
    """Return the list of per-click transforms applied to a prediction.

    The steps are:
    - apply a sigmoid to ``pred``;
    - convert image/label/pred to numpy;
    - write the label/pred discrepancy to ``"discrepancy"``;
    - add random guidance using ``"discrepancy"`` and ``"probability"``;
    - recompute the guidance signal on ``image``;
    - convert image/label back to tensors.

    ``self`` and ``context`` are not used.
    """
    pair = ("image", "label")
    return [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=pair + ("pred",)),
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        ToTensord(keys=pair),
    ]
def get_click_transforms():
    """Return a ``Compose`` of per-click transforms configured for batched input.

    ``FindDiscrepancyRegionsd``, ``AddRandomGuidanced`` and ``AddGuidanceSignald``
    are all constructed with ``batched=True``. Before them, ``ToNumpyd`` converts
    image, label, pred, probability and guidance.
    """
    batched = {"batched": True}
    steps = [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=("image", "label", "pred", "probability", "guidance")),
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy", **batched),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability", **batched),
        AddGuidanceSignald(image="image", guidance="guidance", **batched),
        ToTensord(keys=("image", "label")),
    ]
    return Compose(steps)
def get_click_transforms():
    """Return a ``Compose`` of per-click transforms ending with ``EnsureTyped``.

    The steps are:
    - apply a sigmoid to ``pred``;
    - convert image/label/pred to numpy;
    - compute the label/pred discrepancy;
    - add random guidance from it;
    - recompute the guidance signal;
    - apply ``EnsureTyped`` to image/label.
    """
    steps = [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=("image", "label", "pred")),
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=("image", "label")),
    ]
    return Compose(steps)
def pre_transforms(self, data=None):
    """Return the ordered inference pre-processing transforms for RGB image patches.

    The image patch is loaded as RGB ``np.uint8``, filtered, made channel-first and
    rescaled from [0, 255] to [-1, 1]. Click guidance is then added, followed by the
    guidance signal (3 intensity channels) and ``EnsureTyped``.

    Args:
        data: optional mapping; when truthy, its ``"device"`` entry (or ``None``
            if absent) is forwarded to ``EnsureTyped``.
    """
    device = None
    if data:
        device = data.get("device")

    return [
        LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8),
        FilterImaged(keys="image"),
        AsChannelFirstd(keys="image"),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        AddClickGuidanced(image="image", guidance="guidance"),
        AddGuidanceSignald(image="image", guidance="guidance", number_intensity_ch=3),
        EnsureTyped(keys="image", device=device),
    ]
def pre_transforms(self, data=None):
    """Return the ordered 2-D inference pre-processing transforms.

    The image is loaded, made channel-first and resampled to 1.0 spacing on two axes.
    Guidance is built from the input points (``dimensions=2``) and a 2-D slice is
    fetched. After a channel is added, the image is cropped around the guidance to
    256x256 and resized to 256x256. The guidance is then resized to match, the
    image is normalized, and the guidance signal is added.

    ``self`` and ``data`` are not used.
    """
    img, guide = "image", "guidance"
    return [
        LoadImaged(keys=img),
        AsChannelFirstd(keys=img),
        Spacingd(keys=img, pixdim=[1.0] * 2, mode="bilinear"),
        AddGuidanceFromPointsd(ref_image=img, guidance=guide, dimensions=2),
        Fetch2DSliced(keys=img, guidance=guide),
        AddChanneld(keys=img),
        # Separate list objects per transform, as in the original literals.
        SpatialCropGuidanced(keys=img, guidance=guide, spatial_size=[256] * 2),
        Resized(keys=img, spatial_size=[256] * 2, mode="area"),
        ResizeGuidanced(guidance=guide, ref_image=img),
        NormalizeIntensityd(keys=img, subtrahend=208, divisor=388),  # type: ignore
        AddGuidanceSignald(image=img, guidance=guide),
    ]
def get_pre_transforms(roi_size, model_size, dimensions):
    """Return a ``Compose`` of training pre-processing transforms (``ToTensord`` variant).

    Args:
        roi_size: ``spatial_size`` for ``SpatialCropForegroundd``.
        model_size: ``spatial_size`` for ``Resized``.
        dimensions: when equal to 3, a ``FindAllValidSlicesd`` step writing
            ``"sids"`` is inserted before the initial seed point.
    """
    pair = ("image", "label")
    head = [
        LoadImaged(keys=pair),
        AddChanneld(keys=pair),
        SpatialCropForegroundd(keys=pair, source_key="label", spatial_size=roi_size),
        Resized(keys=pair, spatial_size=model_size, mode=("area", "nearest")),
        NormalizeIntensityd(keys="image", subtrahend=208.0, divisor=388.0),
    ]
    middle = [FindAllValidSlicesd(label="label", sids="sids")] if dimensions == 3 else []
    tail = [
        AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        ToTensord(keys=pair),
    ]
    return Compose(head + middle + tail)
def get_pre_transforms(roi_size, model_size, dimensions):
    """Return a ``Compose`` of training pre-processing transforms (``EnsureTyped`` variant).

    Args:
        roi_size: ``spatial_size`` for ``SpatialCropForegroundd``.
        model_size: ``spatial_size`` for ``Resized``.
        dimensions: when equal to 3, a ``FindAllValidSlicesd`` step writing
            ``"sids"`` is inserted before the initial seed point.
    """
    img_lbl = ("image", "label")
    stages = [
        LoadImaged(keys=img_lbl),
        AddChanneld(keys=img_lbl),
        SpatialCropForegroundd(keys=img_lbl, source_key="label", spatial_size=roi_size),
        Resized(keys=img_lbl, spatial_size=model_size, mode=("area", "nearest")),
        NormalizeIntensityd(keys="image", subtrahend=208.0, divisor=388.0),
    ]
    if dimensions == 3:
        stages += [FindAllValidSlicesd(label="label", sids="sids")]
    stages += [
        AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=img_lbl),
    ]
    return Compose(stages)
def test_correct_results(self, arguments, input_data, expected_result):
    """Check that ``AddGuidanceSignald`` built from ``arguments`` maps ``input_data`` to the expected image.

    The comparison uses a relative tolerance of 1e-5.
    """
    transform = AddGuidanceSignald(**arguments)
    output = transform(input_data)
    np.testing.assert_allclose(output["image"], expected_result, rtol=1e-5)