Example #1
0
    def run_interaction(self, train, compose):
        """Run one SupervisedTrainer epoch through a DeepGrow Interaction
        loop and assert on the engine's observable end state.

        Args:
            train: forwarded to ``Interaction(train=...)``.
            compose: when True, the click transforms are wrapped in
                ``Compose`` before being handed to ``Interaction``; when
                False, the raw list is passed instead.
        """
        # Five identical dummy samples: single-channel 2x2x2 image/label.
        data = [{
            "image": np.ones((1, 2, 2, 2)).astype(np.float32),
            "label": np.ones((1, 2, 2, 2))
        } for _ in range(5)]
        network = torch.nn.Linear(2, 2)
        lr = 1e-3
        opt = torch.optim.SGD(network.parameters(), lr)
        loss = torch.nn.L1Loss()
        # Pre-transforms: find valid label slices, place an initial seed
        # click, and encode it as a guidance-signal channel.
        train_transforms = Compose([
            FindAllValidSlicesd(label="label", sids="sids"),
            AddInitialSeedPointd(label="label",
                                 guidance="guidance",
                                 sids="sids"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ])
        dataset = Dataset(data, transform=train_transforms)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

        # Per-click transforms applied between inner iterations: compare
        # the prediction against the label and sample a corrective click
        # from the discrepancy.
        iteration_transforms = [
            Activationsd(keys="pred", sigmoid=True),
            ToNumpyd(keys=["image", "label", "pred"]),
            FindDiscrepancyRegionsd(label="label",
                                    pred="pred",
                                    discrepancy="discrepancy"),
            AddRandomGuidanced(guidance="guidance",
                               discrepancy="discrepancy",
                               probability="probability"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ]
        iteration_transforms = Compose(
            iteration_transforms) if compose else iteration_transforms

        i = Interaction(transforms=iteration_transforms,
                        train=train,
                        max_interactions=5)
        # Both the pre-composed and raw-list inputs should end up as the
        # same 6-step Compose inside Interaction.
        self.assertEqual(len(i.transforms.transforms), 6,
                         "Mismatch in expected transforms")

        # set up engine
        engine = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=data_loader,
            network=network,
            optimizer=opt,
            loss_function=loss,
            iteration_update=i,
        )
        # add_one fires on every inner-iteration start/complete event;
        # presumably it bumps engine.state.best_metric — TODO confirm
        # against the add_one definition (not visible in this chunk).
        engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED,
                                 add_one)
        engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED,
                                 add_one)

        engine.run()
        self.assertIsNotNone(engine.state.batch[0].get("guidance"),
                             "guidance is missing")
        # NOTE(review): the expected value 9 depends on add_one's exact
        # semantics across 5 interactions x 2 events — verify if changed.
        self.assertEqual(engine.state.best_metric, 9)
Example #2
0
 def train_pre_transforms(self, context: Context):
     """Assemble the training pre-processing chain.

     The image/label pair is cropped around the label foreground, resized
     to the model input size, and intensity-normalized; an initial seed
     point is then placed on the label and encoded as a guidance signal.
     For 3D data, valid slice ids are computed first so the seed point
     lands on a slice that actually contains label.
     """
     transforms: List[Any] = [
         LoadImaged(keys=("image", "label")),
         AddChanneld(keys=("image", "label")),
         SpatialCropForegroundd(keys=("image", "label"),
                                source_key="label",
                                spatial_size=self.roi_size),
         Resized(keys=("image", "label"),
                 spatial_size=self.model_size,
                 mode=("area", "nearest")),
         NormalizeIntensityd(keys="image", subtrahend=208.0,
                             divisor=388.0),  # type: ignore
     ]
     if self.dimension == 3:
         # Only 3D volumes need per-slice validity information.
         transforms.append(FindAllValidSlicesd(label="label", sids="sids"))
     transforms += [
         AddInitialSeedPointd(label="label",
                              guidance="guidance",
                              sids="sids"),
         AddGuidanceSignald(image="image", guidance="guidance"),
         EnsureTyped(keys=("image", "label"), device=context.device),
         SelectItemsd(keys=("image", "label", "guidance")),
     ]
     return transforms
Example #3
0
 def train_pre_transforms(self, context: Context):
     """Training pre-processing for a 2D pathology model.

     Loads an RGB image with its label, filters small objects, augments
     with color jitter and a random 90-degree rotation, rescales the
     8-bit intensities into [-1, 1], then attaches the initial seed point
     and the guidance signal channels.
     """
     color_jitter = TorchVisiond(
         keys="image",
         name="ColorJitter",
         brightness=64.0 / 255.0,
         contrast=0.75,
         saturation=0.25,
         hue=0.04,
     )
     return [
         LoadImaged(keys=("image", "label"), dtype=np.uint8),
         FilterImaged(keys="image", min_size=5),
         AsChannelFirstd(keys="image"),
         AddChanneld(keys="label"),
         # ColorJitter is a torchvision op, so hop to tensor and back.
         ToTensord(keys="image"),
         color_jitter,
         ToNumpyd(keys="image"),
         RandRotate90d(keys=("image", "label"),
                       prob=0.5,
                       spatial_axes=(0, 1)),
         ScaleIntensityRangeD(keys="image",
                              a_min=0.0, a_max=255.0,
                              b_min=-1.0, b_max=1.0),
         AddInitialSeedPointExd(label="label", guidance="guidance"),
         AddGuidanceSignald(image="image",
                            guidance="guidance",
                            number_intensity_ch=3),
         EnsureTyped(keys=("image", "label")),
     ]
Example #4
0
 def pre_transforms(self, data=None) -> Sequence[Callable]:
     """Inference pre-processing, parameterized by ``self.dimension``.

     Resamples to isotropic spacing, converts user click points into
     guidance, crops/resizes around the guidance, normalizes intensity,
     and encodes the guidance as extra signal channels. 2D models
     additionally fetch the single slice indicated by the guidance.
     """
     head = [
         LoadImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         Spacingd(keys="image",
                  pixdim=[1.0] * self.dimension,
                  mode="bilinear"),
         AddGuidanceFromPointsd(ref_image="image",
                                guidance="guidance",
                                dimensions=self.dimension),
     ]
     # 2D models operate on the slice selected by the guidance point.
     slicer = ([Fetch2DSliced(keys="image", guidance="guidance")]
               if self.dimension == 2 else [])
     tail = [
         AddChanneld(keys="image"),
         SpatialCropGuidanced(keys="image",
                              guidance="guidance",
                              spatial_size=self.spatial_size),
         Resized(keys="image", spatial_size=self.model_size, mode="area"),
         # Keep guidance coordinates consistent with the resized image.
         ResizeGuidanced(guidance="guidance", ref_image="image"),
         NormalizeIntensityd(keys="image", subtrahend=208,
                             divisor=388),  # type: ignore
         AddGuidanceSignald(image="image", guidance="guidance"),
         EnsureTyped(keys="image",
                     device=data.get("device") if data else None),
     ]
     return head + slicer + tail
Example #5
0
 def pre_transforms(self, data):
     """3D inference pre-processing: spacing, guidance from click points,
     guidance-driven crop/resize, normalization, and guidance signal."""
     chain = []
     chain.append(LoadImaged(keys="image"))
     chain.append(AsChannelFirstd(keys="image"))
     chain.append(Spacingd(keys="image", pixdim=[1.0, 1.0, 1.0], mode="bilinear"))
     chain.append(AddGuidanceFromPointsd(ref_image="image", guidance="guidance", dimensions=3))
     chain.append(AddChanneld(keys="image"))
     chain.append(SpatialCropGuidanced(keys="image", guidance="guidance", spatial_size=self.spatial_size))
     chain.append(Resized(keys="image", spatial_size=self.model_size, mode="area"))
     # Keep guidance coordinates aligned with the resized image.
     chain.append(ResizeGuidanced(guidance="guidance", ref_image="image"))
     chain.append(NormalizeIntensityd(keys="image", subtrahend=208, divisor=388))
     chain.append(AddGuidanceSignald(image="image", guidance="guidance"))
     return chain
Example #6
0
 def get_click_transforms(self, context: Context):
     """Transforms applied between inner interaction iterations.

     Converts the raw prediction to probabilities, finds where it
     disagrees with the label, samples a corrective click from that
     discrepancy, and re-encodes the clicks as a guidance signal.
     """
     sigmoid = Activationsd(keys="pred", sigmoid=True)
     to_numpy = ToNumpyd(keys=("image", "label", "pred"))
     discrepancy = FindDiscrepancyRegionsd(label="label",
                                           pred="pred",
                                           discrepancy="discrepancy")
     click = AddRandomGuidanced(guidance="guidance",
                                discrepancy="discrepancy",
                                probability="probability")
     signal = AddGuidanceSignald(image="image", guidance="guidance")
     to_tensor = ToTensord(keys=("image", "label"))
     return [sigmoid, to_numpy, discrepancy, click, signal, to_tensor]
Example #7
0
def get_click_transforms():
    """Return the composed click-simulation transforms (batched variant)."""
    steps = [
        Activationsd(keys='pred', sigmoid=True),
        # Discrepancy/guidance computations run on numpy arrays.
        ToNumpyd(keys=('image', 'label', 'pred', 'probability', 'guidance')),
        FindDiscrepancyRegionsd(label='label', pred='pred',
                                discrepancy='discrepancy', batched=True),
        AddRandomGuidanced(guidance='guidance', discrepancy='discrepancy',
                           probability='probability', batched=True),
        AddGuidanceSignald(image='image', guidance='guidance', batched=True),
        ToTensord(keys=('image', 'label')),
    ]
    return Compose(steps)
Example #8
0
File: train.py  Project: wyli/tutorials
def get_click_transforms():
    """Build the click-simulation pipeline run between interactions."""
    # Post-process the network output into numpy probabilities.
    post_pred = [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=("image", "label", "pred")),
    ]
    # Turn prediction/label disagreement into fresh guidance signal.
    guidance = [
        FindDiscrepancyRegionsd(label="label",
                                pred="pred",
                                discrepancy="discrepancy"),
        AddRandomGuidanced(guidance="guidance",
                           discrepancy="discrepancy",
                           probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=("image", "label")),
    ]
    return Compose(post_pred + guidance)
Example #9
0
 def pre_transforms(self, data=None):
     """Inference pre-processing for RGB pathology patches."""
     target_device = data.get("device") if data else None
     return [
         LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8),
         FilterImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         # Map 8-bit intensities into the [-1, 1] range.
         ScaleIntensityRangeD(keys="image",
                              a_min=0.0, a_max=255.0,
                              b_min=-1.0, b_max=1.0),
         AddClickGuidanced(image="image", guidance="guidance"),
         AddGuidanceSignald(image="image",
                            guidance="guidance",
                            number_intensity_ch=3),
         EnsureTyped(keys="image", device=target_device),
     ]
Example #10
0
 def pre_transforms(self, data=None):
     """2D inference pre-processing with a fixed 256x256 working size."""
     size = [256, 256]
     return [
         LoadImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         Spacingd(keys="image", pixdim=[1.0, 1.0], mode="bilinear"),
         AddGuidanceFromPointsd(ref_image="image",
                                guidance="guidance",
                                dimensions=2),
         # 2D model: work on the slice selected by the guidance point.
         Fetch2DSliced(keys="image", guidance="guidance"),
         AddChanneld(keys="image"),
         SpatialCropGuidanced(keys="image",
                              guidance="guidance",
                              spatial_size=size),
         Resized(keys="image", spatial_size=size, mode="area"),
         # Keep guidance coordinates aligned with the resized image.
         ResizeGuidanced(guidance="guidance", ref_image="image"),
         NormalizeIntensityd(keys="image", subtrahend=208,
                             divisor=388),  # type: ignore
         AddGuidanceSignald(image="image", guidance="guidance"),
     ]
Example #11
0
def get_pre_transforms(roi_size, model_size, dimensions):
    """Compose the training pre-transforms.

    Args:
        roi_size: crop size around the label foreground.
        model_size: spatial size the network consumes.
        dimensions: 2 or 3; 3D inputs additionally get valid-slice ids
            so the seed point is placed on a slice containing label.
    """
    slice_finder = ([FindAllValidSlicesd(label='label', sids='sids')]
                    if dimensions == 3 else [])
    return Compose([
        LoadImaged(keys=('image', 'label')),
        AddChanneld(keys=('image', 'label')),
        SpatialCropForegroundd(keys=('image', 'label'),
                               source_key='label',
                               spatial_size=roi_size),
        Resized(keys=('image', 'label'),
                spatial_size=model_size,
                mode=('area', 'nearest')),
        NormalizeIntensityd(keys='image', subtrahend=208.0, divisor=388.0),
    ] + slice_finder + [
        AddInitialSeedPointd(label='label', guidance='guidance', sids='sids'),
        AddGuidanceSignald(image='image', guidance='guidance'),
        ToTensord(keys=('image', 'label')),
    ])
Example #12
0
File: train.py  Project: wyli/tutorials
def get_pre_transforms(roi_size, model_size, dimensions):
    """Compose the training pre-transforms for the tutorial pipeline.

    Args:
        roi_size: crop size around the label foreground.
        model_size: spatial size the network consumes.
        dimensions: 2 or 3; 3D inputs additionally get valid-slice ids.
    """
    chain = []
    chain.append(LoadImaged(keys=("image", "label")))
    chain.append(AddChanneld(keys=("image", "label")))
    chain.append(SpatialCropForegroundd(keys=("image", "label"),
                                        source_key="label",
                                        spatial_size=roi_size))
    chain.append(Resized(keys=("image", "label"),
                         spatial_size=model_size,
                         mode=("area", "nearest")))
    chain.append(NormalizeIntensityd(keys="image",
                                     subtrahend=208.0,
                                     divisor=388.0))
    if dimensions == 3:
        # Only 3D volumes need per-slice validity information.
        chain.append(FindAllValidSlicesd(label="label", sids="sids"))
    chain.append(AddInitialSeedPointd(label="label",
                                      guidance="guidance",
                                      sids="sids"))
    chain.append(AddGuidanceSignald(image="image", guidance="guidance"))
    chain.append(EnsureTyped(keys=("image", "label")))
    return Compose(chain)
 def test_correct_results(self, arguments, input_data, expected_result):
     """Apply AddGuidanceSignald with the given arguments and compare the
     resulting image against the expected array."""
     transform = AddGuidanceSignald(**arguments)
     output = transform(input_data)
     np.testing.assert_allclose(output["image"], expected_result, rtol=1e-5)