def run_interaction(self, train, compose):
    """Run one training epoch whose iteration update is an ``Interaction``.

    Args:
        train: forwarded unchanged to ``Interaction(train=...)``.
        compose: when truthy, the two iteration transforms are wrapped in
            ``Compose`` before being handed to ``Interaction``; otherwise the
            plain list is passed.

    The test asserts that the ``Interaction`` exposes exactly two transforms,
    that the batch left on ``engine.state`` carries a ``"probability"`` entry,
    and that ``engine.state.best_metric`` equals 9 once the epoch finishes.
    """
    # Five one-element samples; image and label hold the same scalar value.
    data = [
        {"image": torch.tensor([float(idx)]), "label": torch.tensor([float(idx)])}
        for idx in range(5)
    ]
    network = torch.nn.Linear(1, 1)
    lr = 1e-3
    opt = torch.optim.SGD(network.parameters(), lr)
    loss = torch.nn.L1Loss()
    dataset = Dataset(data, transform=None)
    # batch_size equals the sample count, so the epoch is a single outer iteration.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

    iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]
    iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms
    interaction = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)
    # ``.transforms.transforms`` is read in both branches, so this relies on
    # Interaction normalising a plain list into an object exposing a
    # ``transforms`` list — NOTE(review): confirm against Interaction.
    self.assertEqual(len(interaction.transforms.transforms), 2, "Mismatch in expected transforms")

    # set up engine
    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=data_loader,
        network=network,
        optimizer=opt,
        loss_function=loss,
        iteration_update=interaction,
    )
    # add_one fires on both inner-iteration events; the expected best_metric of 9
    # presumably comes from add_one's counting over max_interactions=5 — confirm.
    engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
    engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)
    engine.run()

    self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing")
    self.assertEqual(engine.state.best_metric, 9)
def run_interaction(self, train, compose):
    """Train one epoch with a deepgrow-style ``Interaction`` and check its output.

    Args:
        train: forwarded unchanged to ``Interaction(train=...)``.
        compose: when truthy, the six per-interaction transforms are wrapped in
            ``Compose`` before being handed to ``Interaction``; otherwise the
            plain list is passed.

    Checks that the interaction exposes six transforms, that the first entry
    of ``engine.state.batch`` carries a ``"guidance"`` key after the run, and
    that ``engine.state.best_metric`` ends at 9.
    """
    shape = (1, 2, 2, 2)
    # Five independent all-ones samples: float32 image, numpy-default float64 label.
    data = [
        dict(image=np.ones(shape, dtype=np.float32), label=np.ones(shape))
        for _ in range(5)
    ]

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), 1e-3)
    criterion = torch.nn.L1Loss()

    # Preprocessing applied by the Dataset before batching.
    preprocessing = Compose(
        [
            FindAllValidSlicesd(label="label", sids="sids"),
            AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ]
    )
    loader = torch.utils.data.DataLoader(Dataset(data, transform=preprocessing), batch_size=5)

    # Transforms handed to Interaction for its inner interaction loop.
    per_click = [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=["image", "label", "pred"]),
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        ToTensord(keys=("image", "label")),
    ]
    if compose:
        per_click = Compose(per_click)

    interaction = Interaction(transforms=per_click, train=train, max_interactions=5)
    self.assertEqual(len(interaction.transforms.transforms), 6, "Mismatch in expected transforms")

    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=loader,
        network=model,
        optimizer=optimizer,
        loss_function=criterion,
        iteration_update=interaction,
    )
    # Same handler on both inner-iteration events, registered start-then-complete.
    for event in (IterationEvents.INNER_ITERATION_STARTED, IterationEvents.INNER_ITERATION_COMPLETED):
        engine.add_event_handler(event, add_one)
    engine.run()

    # The batch is indexed with [0]; presumably Interaction leaves a per-sample
    # list on engine.state.batch — NOTE(review): confirm.
    self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
    self.assertEqual(engine.state.best_metric, 9)
def run_interaction(self, train):
    """Train one epoch with a DeepEdit-style ``Interaction`` and check its output.

    Args:
        train: forwarded unchanged to ``Interaction(train=...)``.

    Random image/label volumes are generated from ``np.random`` after seeding
    it with 0. The test checks that the interaction exposes four transforms,
    that the first entry of ``engine.state.batch`` carries ``"guidance"``, and
    that ``engine.state.best_metric`` ends at 1.
    """
    label_names = {"spleen": 1, "background": 0}
    n_labels = len(label_names)
    volume = (1, 15, 15, 15)
    np.random.seed(0)

    def _sample():
        # Image is drawn before label for every sample; keep this order so the
        # seeded global RNG produces the same data. label_names is shared, not copied.
        return {
            "image": np.random.randint(0, 256, size=volume).astype(np.float32),
            "label": np.random.randint(0, 2, size=volume),
            "label_names": label_names,
        }

    data = [_sample() for _ in range(5)]

    # 3 input channels — presumably 1 intensity channel plus one guidance channel
    # per label name (2 here); NOTE(review): confirm against AddGuidanceSignalDeepEditd.
    model = torch.nn.Conv3d(3, n_labels, 1)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    criterion = DiceCELoss(to_onehot_y=True, softmax=True)

    preprocessing = Compose(
        [
            FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
            AddInitialSeedPointMissingLabelsd(keys="label", guidance="guidance", sids="sids"),
            AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
            ToTensord(keys=("image", "label")),
        ]
    )
    loader = torch.utils.data.DataLoader(Dataset(data, transform=preprocessing), batch_size=5)

    click_steps = [
        FindDiscrepancyRegionsDeepEditd(keys="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanceDeepEditd(
            keys="NA", guidance="guidance", discrepancy="discrepancy", probability="probability"
        ),
        AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
        ToTensord(keys=("image", "label")),
    ]
    post_steps = [
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=n_labels),
        SplitPredsLabeld(keys="pred"),
        ToTensord(keys=("image", "label")),
    ]
    click_transforms = Compose(click_steps)
    postprocessing = Compose(post_steps)

    interaction = Interaction(
        deepgrow_probability=1.0,
        transforms=click_transforms,
        click_probability_key="probability",
        train=train,
        label_names=label_names,
    )
    self.assertEqual(len(interaction.transforms.transforms), 4, "Mismatch in expected transforms")

    engine = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=loader,
        network=model,
        optimizer=optimizer,
        loss_function=criterion,
        postprocessing=postprocessing,
        iteration_update=interaction,
    )
    # Same handler on both inner-iteration events, registered start-then-complete.
    for event in (IterationEvents.INNER_ITERATION_STARTED, IterationEvents.INNER_ITERATION_COMPLETED):
        engine.add_event_handler(event, add_one)
    engine.run()

    self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
    self.assertEqual(engine.state.best_metric, 1)