Ejemplo n.º 1
0
    def run_interaction(self, train, compose):
        """Drive a one-epoch ``SupervisedTrainer`` whose iteration update is an ``Interaction``.

        Args:
            train: forwarded unchanged as ``Interaction(train=...)``.
            compose: if True, the iteration transforms are wrapped in ``Compose`` before being
                passed to ``Interaction``; otherwise the plain list is passed.
        """
        # Five single-value samples; with batch_size=5 the loader yields them as one batch.
        data = [
            {"image": torch.tensor([float(idx)]), "label": torch.tensor([float(idx)])}
            for idx in range(5)
        ]
        network = torch.nn.Linear(1, 1)
        lr = 1e-3
        opt = torch.optim.SGD(network.parameters(), lr)
        loss = torch.nn.L1Loss()
        dataset = Dataset(data, transform=None)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

        iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]
        iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms

        interaction = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)
        # Whether or not `compose` was set, both transforms must be reachable via `.transforms.transforms`.
        self.assertEqual(len(interaction.transforms.transforms), 2, "Mismatch in expected transforms")

        # set up engine
        engine = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=data_loader,
            network=network,
            optimizer=opt,
            loss_function=loss,
            iteration_update=interaction,
        )
        # NOTE(review): `add_one` is defined outside this view; the expected best_metric of 9 below
        # presumably counts the inner-iteration events it observes — confirm against its definition.
        engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
        engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)

        engine.run()
        self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing")
        self.assertEqual(engine.state.best_metric, 9)
Ejemplo n.º 2
0
    def run_interaction(self, train, compose):
        """Train one epoch where an ``Interaction`` re-applies deepgrow guidance transforms.

        Args:
            train: passed straight through as ``Interaction(train=...)``.
            compose: if True, the per-interaction transform list is wrapped in ``Compose``
                before being handed to ``Interaction``; otherwise the raw list is handed over.
        """
        sample_shape = (1, 2, 2, 2)
        # A fresh pair of arrays per sample (image cast to float32, label left as np.ones' default dtype).
        data = [
            {"image": np.ones(sample_shape).astype(np.float32), "label": np.ones(sample_shape)}
            for _ in range(5)
        ]
        network = torch.nn.Linear(2, 2)
        opt = torch.optim.SGD(network.parameters(), 1e-3)
        loss = torch.nn.L1Loss()

        # Dataset-level transforms; the keyword arguments wire label -> sids -> guidance -> image.
        pre_transforms = Compose(
            [
                FindAllValidSlicesd(label="label", sids="sids"),
                AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
                AddGuidanceSignald(image="image", guidance="guidance"),
                ToTensord(keys=("image", "label")),
            ]
        )
        data_loader = torch.utils.data.DataLoader(Dataset(data, transform=pre_transforms), batch_size=5)

        # Transforms that the Interaction applies between its inner iterations.
        steps = [
            Activationsd(keys="pred", sigmoid=True),
            ToNumpyd(keys=["image", "label", "pred"]),
            FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
            AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ]
        if compose:
            steps = Compose(steps)

        interaction = Interaction(transforms=steps, train=train, max_interactions=5)
        self.assertEqual(len(interaction.transforms.transforms), 6, "Mismatch in expected transforms")

        engine = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=data_loader,
            network=network,
            optimizer=opt,
            loss_function=loss,
            iteration_update=interaction,
        )
        # Register the same counter on both inner-iteration events (STARTED first, as before).
        for event in (IterationEvents.INNER_ITERATION_STARTED, IterationEvents.INNER_ITERATION_COMPLETED):
            engine.add_event_handler(event, add_one)

        engine.run()
        self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
        self.assertEqual(engine.state.best_metric, 9)
Ejemplo n.º 3
0
    def run_interaction(self, train):
        """Train one epoch with a DeepEdit-style ``Interaction`` over random 3-D volumes.

        Args:
            train: passed straight through as ``Interaction(train=...)``.
        """
        label_names = {"spleen": 1, "background": 0}
        volume = (1, 15, 15, 15)
        # Seed the global NumPy RNG so the five random image/label pairs are reproducible.
        np.random.seed(0)
        data = [
            {
                "image": np.random.randint(0, 256, size=volume).astype(np.float32),
                "label": np.random.randint(0, 2, size=volume),
                "label_names": label_names,
            }
            for _ in range(5)
        ]
        # 3 input channels — presumably 1 intensity channel plus one guidance channel per label; confirm.
        network = torch.nn.Conv3d(3, len(label_names), 1)
        opt = torch.optim.Adam(network.parameters(), 1e-3)
        loss = DiceCELoss(to_onehot_y=True, softmax=True)

        pre_transforms = Compose(
            [
                FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
                AddInitialSeedPointMissingLabelsd(keys="label", guidance="guidance", sids="sids"),
                AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
                ToTensord(keys=("image", "label")),
            ]
        )
        data_loader = torch.utils.data.DataLoader(Dataset(data, transform=pre_transforms), batch_size=5)

        # Applied by the Interaction between its inner iterations.
        click_transforms = Compose(
            [
                FindDiscrepancyRegionsDeepEditd(keys="label", pred="pred", discrepancy="discrepancy"),
                AddRandomGuidanceDeepEditd(
                    keys="NA", guidance="guidance", discrepancy="discrepancy", probability="probability"
                ),
                AddGuidanceSignalDeepEditd(keys="image", guidance="guidance", number_intensity_ch=1),
                ToTensord(keys=("image", "label")),
            ]
        )
        # Handed to the trainer as `postprocessing`.
        post_transforms = Compose(
            [
                Activationsd(keys="pred", softmax=True),
                AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=len(label_names)),
                SplitPredsLabeld(keys="pred"),
                ToTensord(keys=("image", "label")),
            ]
        )

        interaction = Interaction(
            deepgrow_probability=1.0,
            transforms=click_transforms,
            click_probability_key="probability",
            train=train,
            label_names=label_names,
        )
        self.assertEqual(len(interaction.transforms.transforms), 4, "Mismatch in expected transforms")

        engine = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=data_loader,
            network=network,
            optimizer=opt,
            loss_function=loss,
            postprocessing=post_transforms,
            iteration_update=interaction,
        )
        # Same counter on both inner-iteration events, STARTED registered first.
        for event in (IterationEvents.INNER_ITERATION_STARTED, IterationEvents.INNER_ITERATION_COMPLETED):
            engine.add_event_handler(event, add_one)

        engine.run()
        self.assertIsNotNone(engine.state.batch[0].get("guidance"), "guidance is missing")
        # NOTE(review): expected value depends on how many events `add_one` (defined elsewhere) sees — confirm.
        self.assertEqual(engine.state.best_metric, 1)