Example #1
 def train_pre_transforms(self, context: Context):
     # Dataset preparation
     t: List[Any] = [
         LoadImaged(keys=("image", "label")),
         AddChanneld(keys=("image", "label")),
         SpatialCropForegroundd(keys=("image", "label"),
                                source_key="label",
                                spatial_size=self.roi_size),
         Resized(keys=("image", "label"),
                 spatial_size=self.model_size,
                 mode=("area", "nearest")),
         NormalizeIntensityd(keys="image", subtrahend=208.0,
                             divisor=388.0),  # type: ignore
     ]
     if self.dimension == 3:
         t.append(FindAllValidSlicesd(label="label", sids="sids"))
     t.extend([
         AddInitialSeedPointd(label="label",
                              guidance="guidance",
                              sids="sids"),
         AddGuidanceSignald(image="image", guidance="guidance"),
         EnsureTyped(keys=("image", "label"), device=context.device),
         SelectItemsd(keys=("image", "label", "guidance")),
     ])
     return t
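A minimal usage sketch for the list above (hypothetical names: "app" is the instance defining train_pre_transforms, "context" supplies the target device); such a list is typically wrapped in monai.transforms.Compose and applied to a data dictionary:

from monai.transforms import Compose

# "app" and "context" are hypothetical placeholders for the surrounding class and its context
transform = Compose(app.train_pre_transforms(context))
sample = transform({"image": "liver_ct.nii.gz", "label": "liver_label.nii.gz"})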
Example #2
    def pre_transforms(self, data=None):
        t = [
            LoadImaged(keys="image", reader="ITKReader"),
            EnsureChannelFirstd(keys="image"),
            Orientationd(keys="image", axcodes="RAS"),
            ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        ]
        if self.type == InferType.DEEPEDIT:
            t.extend(
                [
                    AddGuidanceFromPointsCustomd(ref_image="image", guidance="guidance", label_names=self.labels),
                    Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
                    ResizeGuidanceMultipleLabelCustomd(guidance="guidance", ref_image="image"),
                    AddGuidanceSignalCustomd(
                        keys="image", guidance="guidance", number_intensity_ch=self.number_intensity_ch
                    ),
                ]
            )
        else:
            t.extend(
                [
                    Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
                    DiscardAddGuidanced(
                        keys="image", label_names=self.labels, number_intensity_ch=self.number_intensity_ch
                    ),
                ]
            )

        t.append(EnsureTyped(keys="image", device=data.get("device") if data else None))
        return t
Example #3
 def train_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), dtype=np.uint8),
         FilterImaged(keys="image", min_size=5),
         AsChannelFirstd(keys="image"),
         AddChanneld(keys="label"),
         ToTensord(keys="image"),
         TorchVisiond(keys="image",
                      name="ColorJitter",
                      brightness=64.0 / 255.0,
                      contrast=0.75,
                      saturation=0.25,
                      hue=0.04),
         ToNumpyd(keys="image"),
         RandRotate90d(keys=("image", "label"),
                       prob=0.5,
                       spatial_axes=(0, 1)),
         ScaleIntensityRangeD(keys="image",
                              a_min=0.0,
                              a_max=255.0,
                              b_min=-1.0,
                              b_max=1.0),
         AddInitialSeedPointExd(label="label", guidance="guidance"),
         AddGuidanceSignald(image="image",
                            guidance="guidance",
                            number_intensity_ch=3),
         EnsureTyped(keys=("image", "label")),
     ]
Example #4
 def test_dict(self):
     # simulate complicated input data
     test_data = {
         "img": np.array([1.0, 2.0], dtype=np.float32),
         "meta": {
             "dims": 3,
             "size": np.array([1, 2, 3]),
             "path": "temp/test"
         },
         "extra": None,
     }
     for dtype in ("tensor", "numpy"):
         result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"]
         self.assertTrue(isinstance(result, dict))
         self.assertTrue(
             isinstance(result["img"],
                        torch.Tensor if dtype == "tensor" else np.ndarray))
         torch.testing.assert_allclose(result["img"],
                                       torch.as_tensor([1.0, 2.0]))
         self.assertTrue(
             isinstance(result["meta"]["size"],
                        torch.Tensor if dtype == "tensor" else np.ndarray))
         torch.testing.assert_allclose(result["meta"]["size"],
                                       torch.as_tensor([1, 2, 3]))
         self.assertEqual(result["meta"]["path"], "temp/test")
         self.assertEqual(result["extra"], None)
Example #5
 def test_string(self):
     for dtype in ("tensor", "numpy"):
         # string input
         result = EnsureTyped(keys="data", data_type=dtype)({"data": "test_string"})["data"]
         self.assertTrue(isinstance(result, str))
         self.assertEqual(result, "test_string")
         # numpy array of string
         result = EnsureTyped(keys="data", data_type=dtype)({"data": np.array(["test_string"])})["data"]
         self.assertTrue(isinstance(result, np.ndarray))
         self.assertEqual(result[0], "test_string")
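The behaviour exercised by these two tests can be reproduced standalone; a minimal sketch:

import numpy as np
import torch
from monai.transforms import EnsureTyped

# EnsureTyped converts the values under "keys" to the requested container type,
# recursing into dicts and sequences while leaving strings untouched
out = EnsureTyped(keys="data", data_type="tensor")({"data": np.array([1.0, 2.0])})
assert isinstance(out["data"], torch.Tensor)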
Example #6
 def pre_transforms(self, data=None) -> Sequence[Callable]:
     t = [
         LoadImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         Spacingd(keys="image",
                  pixdim=[1.0] * self.dimension,
                  mode="bilinear"),
         AddGuidanceFromPointsd(ref_image="image",
                                guidance="guidance",
                                dimensions=self.dimension),
     ]
     if self.dimension == 2:
         t.append(Fetch2DSliced(keys="image", guidance="guidance"))
     t.extend([
         AddChanneld(keys="image"),
         SpatialCropGuidanced(keys="image",
                              guidance="guidance",
                              spatial_size=self.spatial_size),
         Resized(keys="image", spatial_size=self.model_size, mode="area"),
         ResizeGuidanced(guidance="guidance", ref_image="image"),
         NormalizeIntensityd(keys="image", subtrahend=208,
                             divisor=388),  # type: ignore
         AddGuidanceSignald(image="image", guidance="guidance"),
         EnsureTyped(keys="image",
                     device=data.get("device") if data else None),
     ])
     return t
Example #7
 def train_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label")),
         AddChanneld(keys=("image", "label")),
         Spacingd(
             keys=("image", "label"),
             pixdim=(1.0, 1.0, 1.0),
             mode=("bilinear", "nearest"),
         ),
         ScaleIntensityRanged(keys="image",
                              a_min=-57,
                              a_max=164,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         CropForegroundd(keys=("image", "label"), source_key="image"),
         EnsureTyped(keys=("image", "label"), device=context.device),
         RandCropByPosNegLabeld(
             keys=("image", "label"),
             label_key="label",
             spatial_size=(96, 96, 96),
             pos=1,
             neg=1,
             num_samples=4,
             image_key="image",
             image_threshold=0,
         ),
         RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
         SelectItemsd(keys=("image", "label")),
     ]
Example #8
 def post_transforms(self, data=None) -> Sequence[Callable]:
     return [
         EnsureTyped(keys="pred", device=data.get("device") if data else None),
         Activationsd(keys="pred", softmax=True),
         AsDiscreted(keys="pred", argmax=True),
         SqueezeDimd(keys="pred", dim=0),
         ToNumpyd(keys="pred"),
         Restored(keys="pred", ref_image="image"),
     ]
Example #9
 def post_transforms(self, data=None) -> Sequence[Callable]:
     return [
         EnsureTyped(keys="pred",
                     device=data.get("device") if data else None),
         Activationsd(keys="pred", sigmoid=True),
         AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
         ToNumpyd(keys="pred"),
         RestoreLabeld(keys="pred", ref_image="image", mode="nearest"),
         AsChannelLastd(keys="pred"),
     ]
Example #10
 def post_transforms(self, data=None) -> Sequence[Callable]:
     return [
         EnsureTyped(keys="pred",
                     device=data.get("device") if data else None),
         Activationsd(keys="pred", softmax=True),
         AsDiscreted(keys="pred", argmax=True),
         ToNumpyd(keys="pred"),
         Restored(keys="pred", ref_image="image"),
         BoundingBoxd(keys="pred", result="result", bbox="bbox"),
     ]
Example #11
 def test_list_tuple(self):
     for dtype in ("tensor", "numpy"):
         result = EnsureTyped(keys="data", data_type=dtype)({"data": [[1, 2], [3, 4]]})["data"]
         self.assertTrue(isinstance(result, list))
         self.assertTrue(
             isinstance(result[0][1],
                        torch.Tensor if dtype == "tensor" else np.ndarray))
         torch.testing.assert_allclose(result[1][0], torch.as_tensor(3))
         # tuple of numpy arrays
         result = EnsureTyped(keys="data", data_type=dtype)({"data": (np.array([1, 2]), np.array([3, 4]))})["data"]
         self.assertTrue(isinstance(result, tuple))
         self.assertTrue(
             isinstance(result[0],
                        torch.Tensor if dtype == "tensor" else np.ndarray))
         torch.testing.assert_allclose(result[1], torch.as_tensor([3, 4]))
Example #12
 def post_transforms(self, data=None) -> Sequence[Callable]:
     return [
         EnsureTyped(keys="pred", device=data.get("device") if data else None),
         Activationsd(keys="pred", softmax=len(self.labels) > 1, sigmoid=len(self.labels) == 1),
         AsDiscreted(keys="pred", argmax=len(self.labels) > 1, threshold=0.5 if len(self.labels) == 1 else None),
         SqueezeDimd(keys="pred", dim=0),
         ToNumpyd(keys=("image", "pred")),
         PostFilterLabeld(keys="pred", image="image"),
         FindContoursd(keys="pred", labels=self.labels),
     ]
Example #13
 def pre_transforms(self, data=None):
     return [
         LoadImagePatchd(keys="image",
                         conversion="RGB",
                         dtype=np.uint8,
                         padding=False),
         AsChannelFirstd(keys="image"),
         AddClickSignalsd(image="image"),
         EnsureTyped(keys="image",
                     device=data.get("device") if data else None),
     ]
Example #14
 def train_post_transforms(self, context: Context):
     return [
         EnsureTyped(keys="pred", device=context.device),
         Activationsd(keys="pred",
                      softmax=len(self._labels) > 1,
                      sigmoid=len(self._labels) == 1),
         AsDiscreted(
             keys=("pred", "label"),
             argmax=(True, False),
             to_onehot=(len(self._labels) + 1, len(self._labels) + 1),
         ),
     ]
Example #15
 def pre_transforms(self, data=None) -> Sequence[Callable]:
     return [
         LoadImaged(keys="image", reader="ITKReader"),
         EnsureChannelFirstd(keys="image"),
         Spacingd(keys="image", pixdim=self.target_spacing),
         ScaleIntensityRanged(keys="image",
                              a_min=-175,
                              a_max=250,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         EnsureTyped(keys="image"),
     ]
Example #16
 def pre_transforms(self, data=None) -> Sequence[Callable]:
     return [
         LoadImaged(keys="image"),
         AddChanneld(keys="image"),
         Spacingd(keys="image", pixdim=[1.0, 1.0, 1.0]),
         ScaleIntensityRanged(keys="image",
                              a_min=-57,
                              a_max=164,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         EnsureTyped(keys="image"),
     ]
Example #17
def get_click_transforms():
    return Compose([
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=("image", "label", "pred")),
        FindDiscrepancyRegionsd(label="label",
                                pred="pred",
                                discrepancy="discrepancy"),
        AddRandomGuidanced(
            guidance="guidance",
            discrepancy="discrepancy",
            probability="probability",
        ),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=("image", "label")),
    ])
Example #18
 def val_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), reader="ITKReader"),
         EnsureChannelFirstd(keys=("image", "label")),
         Spacingd(keys=("image", "label"),
                  pixdim=self.target_spacing,
                  mode=("bilinear", "nearest")),
         ScaleIntensityRanged(keys="image",
                              a_min=-175,
                              a_max=250,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         EnsureTyped(keys=("image", "label")),
         SelectItemsd(keys=("image", "label")),
     ]
Example #19
def get_xforms(mode="train", keys=("image", "label")):
    """returns a composed transform for train/val/infer."""

    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys,
                 pixdim=(1.25, 1.25, 5.0),
                 mode=("bilinear", "nearest")[:len(keys)]),
        ScaleIntensityRanged(keys[0],
                             a_min=-1000.0,
                             a_max=500.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
    ]
    if mode == "train":
        xforms.extend([
            SpatialPadd(keys, spatial_size=(192, 192, -1),
                        mode="reflect"),  # ensure at least 192x192
            RandAffined(
                keys,
                prob=0.15,
                rotate_range=(
                    0.05, 0.05, None
                ),  # 3 parameters control the transform on 3 dimensions
                scale_range=(0.1, 0.1, None),
                mode=("bilinear", "nearest"),
                as_tensor_output=False,
            ),
            RandCropByPosNegLabeld(keys,
                                   label_key=keys[1],
                                   spatial_size=(192, 192, 16),
                                   num_samples=3),
            RandGaussianNoised(keys[0], prob=0.15, std=0.01),
            RandFlipd(keys, spatial_axis=0, prob=0.5),
            RandFlipd(keys, spatial_axis=1, prob=0.5),
            RandFlipd(keys, spatial_axis=2, prob=0.5),
        ])
        dtype = (np.float32, np.uint8)
    if mode == "val":
        dtype = (np.float32, np.uint8)
    if mode == "infer":
        dtype = (np.float32,)
    xforms.extend([CastToTyped(keys, dtype=dtype), EnsureTyped(keys)])
    return monai.transforms.Compose(xforms)
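A usage sketch for this factory (file paths are placeholders); note that RandCropByPosNegLabeld(num_samples=3) turns each input dictionary into a list of three cropped dictionaries in train mode:

# hypothetical file paths
train_transform = get_xforms("train", keys=("image", "label"))
samples = train_transform({"image": "ct.nii.gz", "label": "ct_seg.nii.gz"})  # list of 3 crops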
Example #20
 def pre_transforms(self, data=None):
     return [
         LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8),
         FilterImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         ScaleIntensityRangeD(keys="image",
                              a_min=0.0,
                              a_max=255.0,
                              b_min=-1.0,
                              b_max=1.0),
         AddClickGuidanced(image="image", guidance="guidance"),
         AddGuidanceSignald(image="image",
                            guidance="guidance",
                            number_intensity_ch=3),
         EnsureTyped(keys="image",
                     device=data.get("device") if data else None),
     ]
Example #21
 def val_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label")),
         AddChanneld(keys=("image", "label")),
         Spacingd(
             keys=("image", "label"),
             pixdim=(1.0, 1.0, 1.0),
             mode=("bilinear", "nearest"),
         ),
         ScaleIntensityRanged(keys="image",
                              a_min=-57,
                              a_max=164,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         CropForegroundd(keys=("image", "label"), source_key="image"),
         EnsureTyped(keys=("image", "label"), device=context.device),
     ]
Example #22
 def test_single_input(self):
     test_datas = [5, 5.0, False, np.asarray(5), torch.tensor(5)]
     if torch.cuda.is_available():
         test_datas.append(test_datas[-1].cuda())
     for test_data in test_datas:
         for dtype in ("tensor", "numpy"):
             result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"]
             self.assertTrue(
                 isinstance(
                     result,
                     torch.Tensor if dtype == "tensor" else np.ndarray))
             if isinstance(test_data, bool):
                 self.assertFalse(result)
             else:
                 assert_allclose(result, test_data, type_test=False)
             self.assertEqual(result.ndim, 0)
Example #23
def get_pre_transforms(roi_size, model_size, dimensions):
    t = [
        LoadImaged(keys=("image", "label")),
        AddChanneld(keys=("image", "label")),
        SpatialCropForegroundd(keys=("image", "label"),
                               source_key="label",
                               spatial_size=roi_size),
        Resized(keys=("image", "label"),
                spatial_size=model_size,
                mode=("area", "nearest")),
        NormalizeIntensityd(keys="image", subtrahend=208.0, divisor=388.0),
    ]
    if dimensions == 3:
        t.append(FindAllValidSlicesd(label="label", sids="sids"))
    t.extend([
        AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=("image", "label")),
    ])
    return Compose(t)
Example #24
 def post_transforms(self, data=None) -> Sequence[Callable]:
     largest_cc = False if not data else data.get("largest_cc", False)
     applied_labels = list(self.labels.values()) if isinstance(self.labels, dict) else self.labels
     t = [
         EnsureTyped(keys="pred",
                     device=data.get("device") if data else None),
         Activationsd(keys="pred",
                      softmax=len(self.labels) > 1,
                      sigmoid=len(self.labels) == 1),
         AsDiscreted(keys="pred",
                     argmax=len(self.labels) > 1,
                     threshold=0.5 if len(self.labels) == 1 else None),
     ]
     if largest_cc:
         t.append(
             KeepLargestConnectedComponentd(keys="pred",
                                            applied_labels=applied_labels))
     t.extend([
         ToNumpyd(keys="pred"),
         Restored(keys="pred", ref_image="image"),
     ])
     return t
Example #25
 def train_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), reader="ITKReader"),
         NormalizeLabelsInDatasetd(
             keys="label",
             label_names=self._labels),  # especially for handling missing labels
         EnsureChannelFirstd(keys=("image", "label")),
         Spacingd(keys=("image", "label"),
                  pixdim=self.target_spacing,
                  mode=("bilinear", "nearest")),
         CropForegroundd(keys=("image", "label"), source_key="image"),
         SpatialPadd(keys=("image", "label"),
                     spatial_size=self.spatial_size),
         ScaleIntensityRanged(keys="image",
                              a_min=-175,
                              a_max=250,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         RandCropByPosNegLabeld(
             keys=("image", "label"),
             label_key="label",
             spatial_size=self.spatial_size,
             pos=1,
             neg=1,
             num_samples=self.num_samples,
             image_key="image",
             image_threshold=0,
         ),
         EnsureTyped(keys=("image", "label"), device=context.device),
         RandFlipd(keys=("image", "label"), spatial_axis=[0], prob=0.10),
         RandFlipd(keys=("image", "label"), spatial_axis=[1], prob=0.10),
         RandFlipd(keys=("image", "label"), spatial_axis=[2], prob=0.10),
         RandRotate90d(keys=("image", "label"), prob=0.10, max_k=3),
         RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
         SelectItemsd(keys=("image", "label")),
     ]
Example #26
 def test_array_input(self):
     test_datas = [
         np.array([[1, 2], [3, 4]]),
         torch.as_tensor([[1, 2], [3, 4]])
     ]
     if torch.cuda.is_available():
         test_datas.append(test_datas[-1].cuda())
     for test_data in test_datas:
         for dtype in ("tensor", "NUMPY"):
             result = EnsureTyped(
                 keys="data",
                 data_type=dtype,
                 dtype=np.float32 if dtype == "NUMPY" else None,
                 device="cpu")({"data": test_data})["data"]
             if dtype == "NUMPY":
                 self.assertTrue(result.dtype == np.float32)
             self.assertTrue(
                 isinstance(
                     result,
                     torch.Tensor if dtype == "tensor" else np.ndarray))
             assert_allclose(result, test_data, type_test=False)
             self.assertTupleEqual(result.shape, (2, 2))
Example #27
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = (
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100))
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True,
                        dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test EnsureTensor for complicated dict data and invert it
            CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
            # test to support Tensor, Numpy array and dictionary when inverting
            EnsureTyped(keys=["image", "test_dict"]),
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label",
                       times=2,
                       names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image",
                       times=2,
                       names=["image_inverted", "image_inverted1"]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num_workers = 0 for non-linux platforms or GPU transforms
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
        inverter = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted", "label_inverted", "test_dict"],
            transform=transform,
            orig_keys=["label", "label", "test_dict"],
            meta_keys=[
                PostFix.meta("image_inverted"),
                PostFix.meta("label_inverted"), None
            ],
            orig_meta_keys=[
                PostFix.meta("label"),
                PostFix.meta("label"), None
            ],
            nearest_interp=True,
            to_tensor=[True, False, False],
            device="cpu",
        )

        inverter_1 = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted1", "label_inverted1"],
            transform=transform,
            orig_keys=["image", "image"],
            meta_keys=[
                PostFix.meta("image_inverted1"),
                PostFix.meta("label_inverted1")
            ],
            orig_meta_keys=[PostFix.meta("image"),
                            PostFix.meta("image")],
            nearest_interp=[True, False],
            to_tensor=[True, True],
            device="cpu",
        )

        expected_keys = [
            "image",
            "image_inverted",
            "image_inverted1",
            PostFix.meta("image_inverted1"),
            PostFix.meta("image_inverted"),
            PostFix.meta("image"),
            "image_transforms",
            "label",
            "label_inverted",
            "label_inverted1",
            PostFix.meta("label_inverted1"),
            PostFix.meta("label_inverted"),
            PostFix.meta("label"),
            "label_transforms",
            "test_dict",
            "test_dict_transforms",
        ]
        # execute 1 epoch
        for d in loader:
            d = decollate_batch(d)
            for item in d:
                item = inverter(item)
                item = inverter_1(item)

                self.assertListEqual(sorted(item), expected_keys)
                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                i = item["image_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                i = item["label_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                # test inverted test_dict
                self.assertTrue(
                    isinstance(item["test_dict"]["affine"], np.ndarray))
                self.assertTrue(
                    isinstance(item["test_dict"]["filename_or_obj"], str))

                # check the case where different items use different interpolation modes to invert transforms
                d = item["image_inverted1"]
                # if the interpolation mode is nearest, accumulated diff should be smaller than 1
                self.assertLess(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 1.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

                d = item["label_inverted1"]
                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
                self.assertGreater(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 10000.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

        # check labels match
        reverted = item["label_inverted"].detach().cpu().numpy().astype(
            np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 34007: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1821: windows torch 1.10.0
        self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821),
                        f"diff.  {reverted.size - n_good}")

        set_determinism(seed=None)
Example #28
def evaluate(args):
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False, num_replicas=hvd.size(), rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(mp, "_supports_context") and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference needs to take 1 image per iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # create UNet and move it to the GPU device
    device = torch.device(f"cuda:{hvd.local_rank()}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if hvd.rank() == 0:
            print("evaluation metric:", metric)
Example #29
def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num,
                        num_samples):
    if mode != "test":
        keys = ["image", "label"]
    else:
        keys = ["image"]

    # 1. loading
    load_transforms = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
    ]
    # 2. sampling
    sample_transforms = [
        PreprocessAnisotropic(
            keys=keys,
            clip_values=clip_values[task_id],
            pixdim=spacing[task_id],
            normalize_values=normalize_values[task_id],
            model_mode=mode,
        ),
    ]
    # 3. spatial transforms
    if mode == "train":
        other_transforms = [
            SpatialPadd(keys=["image", "label"],
                        spatial_size=patch_size[task_id]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=patch_size[task_id],
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("trilinear", "nearest"),
                align_corners=(True, None),
                prob=0.15,
            ),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5),
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    elif mode == "validation":
        other_transforms = [
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    else:
        other_transforms = [
            CastToTyped(keys=["image"], dtype=np.float32),
            EnsureTyped(keys=["image"]),
        ]

    all_transforms = load_transforms + sample_transforms + other_transforms
    return Compose(all_transforms)
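A usage sketch, assuming the module-level lookup tables referenced above (clip_values, spacing, normalize_values, patch_size) are dicts keyed by task_id; they are used but not defined in this snippet:

# hypothetical task id; the lookup tables must contain this key
train_transform = get_task_transforms("train", task_id="04", pos_sample_num=2,
                                      neg_sample_num=1, num_samples=2)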
Example #30
def train(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    elif not os.path.exists(args.dir):
        # create 40 random image/mask pairs for training
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # partition the dataset based on the current rank number so that every rank trains with its own data;
    # this avoids caching duplicated content in each rank, but no global shuffle is done before each epoch
    data_part = partition_dataset(
        data=train_files,
        num_partitions=dist.get_world_size(),
        shuffle=True,
        even_divisible=True,
    )[dist.get_rank()]

    train_ds = SmartCacheDataset(
        data=data_part,
        transform=train_transforms,
        replace_rate=0.2,
        cache_num=15,  # we assume 2 ranks in this example, so every rank has 20 training images
        num_init_workers=2,
        num_replace_workers=2,
    )
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # start a typical PyTorch training
    epoch_loss_values = list()
    # start the replacement thread of SmartCache
    train_ds.start()

    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        # replace 20% of cache content for next epoch
        train_ds.update_cache()
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # stop replacement thread of SmartCache
    train_ds.shutdown()
    print(f"train completed, epoch losses: {epoch_loss_values}")
    if dist.get_rank() == 0:
        # all processes should see the same parameters: they all start from the same
        # random parameters and gradients are synchronized in backward passes,
        # so saving the model in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
    dist.destroy_process_group()
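This training function expects one process per GPU with the NCCL backend and env:// initialization; with older PyTorch versions a typical invocation (script name hypothetical) is "python -m torch.distributed.launch --nproc_per_node=2 unet_training_smartcache.py", which sets the environment variables read by init_process_group and passes --local_rank to each process.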