Example #1
    def test_type_cupy(self, input_param, input_data, expected_type):
        input_data = {k: cp.asarray(v) for k, v in input_data.items()}

        result = CastToTyped(**input_param)(input_data)
        for k, v in result.items():
            self.assertTrue(isinstance(v, cp.ndarray))
            self.assertEqual(v.dtype, expected_type[k])
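
All of the snippets on this page exercise monai.transforms.CastToTyped, the dictionary-based
cast transform. For quick reference, here is a minimal sketch of the basic call pattern
(key names, shapes, and dtypes are illustrative, not taken from the examples; assumes a
reasonably recent MONAI):

import numpy as np
import torch
from monai.transforms import CastToTyped

# a dictionary sample holding a NumPy image and a torch label
data = {"image": np.zeros((1, 8, 8)), "label": torch.zeros((1, 8, 8), dtype=torch.int64)}

# a single dtype is applied to every listed key; the container type
# (ndarray vs. Tensor) is preserved, only the element dtype changes
cast = CastToTyped(keys=["image", "label"], dtype=np.float32)
out = cast(data)
print(out["image"].dtype, out["label"].dtype)  # float32 torch.float32
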
Example #2
 def test_value_2d(self, data, expected_mask):
     test_dtype = [torch.float32, torch.float16]
     for dtype in test_dtype:
         data = CastToTyped(keys=["image", "boxes"], dtype=dtype)(data)
         transform_to_mask = BoxToMaskd(
             box_keys="boxes",
             box_mask_keys="box_mask",
             box_ref_image_keys="image",
             label_keys="labels",
             min_fg_label=0,
             ellipse_mask=False,
         )
         transform_to_box = MaskToBoxd(box_keys="boxes",
                                       box_mask_keys="box_mask",
                                       label_keys="labels",
                                       min_fg_label=0)
         data_mask = transform_to_mask(data)
         assert_allclose(data_mask["box_mask"],
                         expected_mask,
                         type_test=True,
                         device_test=True,
                         atol=1e-3)
         data_back = transform_to_box(data_mask)
         assert_allclose(data_back["boxes"],
                         data["boxes"],
                         type_test=False,
                         device_test=False,
                         atol=1e-3)
         assert_allclose(data_back["labels"],
                         data["labels"],
                         type_test=False,
                         device_test=False,
                         atol=1e-3)
Example #3
def get_xforms_load(mode="load", keys=("image", "label")):
    """returns a composed transform."""
    xforms = [
        LoadImaged(keys),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
    ]
    if mode == "load":
        dtype = (np.float32, np.uint8)
    xforms.extend([CastToTyped(keys, dtype=dtype)])
    return monai.transforms.Compose(xforms)
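
Two things are worth calling out in Example #3 and the nearly identical Example #4 below.
First, the dtype tuple pairs with keys positionally: the image is cast to a float type for
intensity work while the label stays an integer mask. Second, dtype is only bound when
mode == "load", so any other mode would raise a NameError at the CastToTyped call. A minimal
defensive sketch (a simplification, not the original author's code; the scaling step is
omitted for brevity):

import numpy as np
import monai
from monai.transforms import CastToTyped, LoadImaged


def get_xforms_load_safe(keys=("image", "label")):
    """Variant of get_xforms_load above in which dtype is always bound, so no NameError."""
    # one dtype per key, in order: float image for intensity work, integer label mask
    dtype = (np.float32, np.uint8)
    xforms = [LoadImaged(keys), CastToTyped(keys, dtype=dtype)]
    return monai.transforms.Compose(xforms)
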
Example #4
def get_xforms_load(mode="load", keys=("image", "label")):
    """returns a composed transform for train/val/infer."""

    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
        # ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
    ]
    if mode == "load":
        dtype = (np.int16, np.uint8)
    xforms.extend([CastToTyped(keys, dtype=dtype), ToTensord(keys)])
    return monai.transforms.Compose(xforms)
Example #5
def get_xforms(args, mode="train", keys=("image", "label")):
    """returns a composed transform for train/val/infer."""

    xforms = [
        LoadNiftid(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys,
                 pixdim=(1.25, 1.25, 5.0),
                 mode=("bilinear", "nearest")[:len(keys)]),
        ScaleIntensityRanged(keys[0],
                             a_min=-1000.0,
                             a_max=500.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
    ]
    if mode == "train":
        xforms.extend([
            SpatialPadd(keys,
                        spatial_size=(args.patch_size, args.patch_size, -1),
                        mode="reflect"),  # ensure at least 192x192
            RandAffined(
                keys,
                prob=0.15,
                rotate_range=(-0.05, 0.05),
                scale_range=(-0.1, 0.1),
                mode=("bilinear", "nearest"),
                as_tensor_output=False,
            ),
            RandCropByPosNegLabeld(keys,
                                   label_key=keys[1],
                                   spatial_size=(args.patch_size,
                                                 args.patch_size,
                                                 args.n_slice),
                                   num_samples=3),
            RandGaussianNoised(keys[0], prob=0.15, std=0.01),
            RandFlipd(keys, spatial_axis=0, prob=0.5),
            RandFlipd(keys, spatial_axis=1, prob=0.5),
            RandFlipd(keys, spatial_axis=2, prob=0.5),
        ])
        dtype = (np.float32, np.uint8)
    if mode == "val":
        dtype = (np.float32, np.uint8)
    if mode == "infer":
        dtype = (np.float32, )
    xforms.extend([CastToTyped(keys, dtype=dtype), ToTensord(keys)])
    return monai.transforms.Compose(xforms)
Example #6
 def test_value_3d(
     self,
     keys,
     data,
     expected_convert_result,
     expected_zoom_result,
     expected_zoom_keepsize_result,
     expected_flip_result,
     expected_clip_result,
     expected_rotate_result,
 ):
     test_dtype = [torch.float32]
     for dtype in test_dtype:
         data = CastToTyped(keys=["image", "boxes"], dtype=dtype)(data)
         # test ConvertBoxToStandardModed
         transform_convert_mode = ConvertBoxModed(**keys)
         convert_result = transform_convert_mode(data)
         assert_allclose(convert_result["boxes"],
                         expected_convert_result,
                         type_test=True,
                         device_test=True,
                         atol=1e-3)
Example #7
def get_xforms_with_synthesis(mode="synthesis", keys=("image", "label"), keys2=("image", "label", "synthetic_lesion")):
    """returns a composed transform for train/val/infer."""

    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
        CopyItemsd(keys, 1, names=['image_1', 'label_1']),
    ]
    if mode == "synthesis":
        xforms.extend([
                  SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),  # ensure at least 192x192
                  RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=3),
                  TransCustom(keys, path_synthesis, read_cea_aug_slice2, 
                              pseudo_healthy_with_texture, scans_syns, decreasing_sequence, GEN=15,
                              POST_PROCESS=True, mask_outer_ring=True, new_value=.5),
                  RandAffined(
                      # keys,
                      keys2,
                      prob=0.15,
                      rotate_range=(0.05, 0.05, None),  # 3 parameters control the transform on 3 dimensions
                      scale_range=(0.1, 0.1, None), 
                      mode=("bilinear", "nearest", "bilinear"),
                      # mode=("bilinear", "nearest"),
                      as_tensor_output=False
                  ),
                  RandGaussianNoised((keys2[0],keys2[2]), prob=0.15, std=0.01),
                  # RandGaussianNoised(keys[0], prob=0.15, std=0.01),
                  RandFlipd(keys, spatial_axis=0, prob=0.5),
                  RandFlipd(keys, spatial_axis=1, prob=0.5),
                  RandFlipd(keys, spatial_axis=2, prob=0.5),
                  TransCustom2(0.333)
              ])
    dtype = (np.float32, np.uint8)
    # dtype = (np.float32, np.uint8, np.float32)
    xforms.extend([CastToTyped(keys, dtype=dtype)])
    return monai.transforms.Compose(xforms)
Example #8
 def test_value_3d_mask(self):
     test_dtype = [torch.float32, torch.float16]
     image = np.zeros((1, 32, 33, 34))
     boxes = np.array([[7, 8, 9, 10, 12, 13], [1, 3, 5, 2, 5, 9],
                       [0, 0, 0, 1, 1, 1]])
     data = {"image": image, "boxes": boxes, "labels": np.array((1, 0, 3))}
     for dtype in test_dtype:
         data = CastToTyped(keys=["image", "boxes"], dtype=dtype)(data)
         transform_to_mask = BoxToMaskd(
             box_keys="boxes",
             box_mask_keys="box_mask",
             box_ref_image_keys="image",
             label_keys="labels",
             min_fg_label=0,
             ellipse_mask=False,
         )
         transform_to_box = MaskToBoxd(box_keys="boxes",
                                       box_mask_keys="box_mask",
                                       label_keys="labels",
                                       min_fg_label=0)
         data_mask = transform_to_mask(data)
         assert_allclose(data_mask["box_mask"].shape, (3, 32, 33, 34),
                         type_test=True,
                         device_test=True,
                         atol=1e-3)
         data_back = transform_to_box(data_mask)
         assert_allclose(data_back["boxes"],
                         data["boxes"],
                         type_test=False,
                         device_test=False,
                         atol=1e-3)
         assert_allclose(data_back["labels"],
                         data["labels"],
                         type_test=False,
                         device_test=False,
                         atol=1e-3)
Example #9
def get_xforms_scans_or_synthetic_lesions(mode="scans",
                                          keys=("image", "label")):
    """returns a composed transform for scans or synthetic lesions."""
    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys,
                 pixdim=(1.25, 1.25, 5.0),
                 mode=("bilinear", "nearest")[:len(keys)]),
    ]
    dtype = (np.int16, np.uint8)
    if mode == "synthetic":
        xforms.extend([
            ScaleIntensityRanged(keys[0],
                                 a_min=-1000.0,
                                 a_max=500.0,
                                 b_min=0.0,
                                 b_max=1.0,
                                 clip=True),
        ])
        dtype = (np.float32, np.uint8)
    xforms.extend([CastToTyped(keys, dtype=dtype)])
    return monai.transforms.Compose(xforms)
Example #10
def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num,
                        num_samples):
    if mode != "test":
        keys = ["image", "label"]
    else:
        keys = ["image"]

    load_transforms = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
    ]
    # 2. sampling
    sample_transforms = [
        PreprocessAnisotropic(
            keys=keys,
            clip_values=clip_values[task_id],
            pixdim=spacing[task_id],
            normalize_values=normalize_values[task_id],
            model_mode=mode,
        ),
    ]
    # 3. spatial transforms
    if mode == "train":
        other_transforms = [
            SpatialPadd(keys=["image", "label"],
                        spatial_size=patch_size[task_id]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=patch_size[task_id],
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("trilinear", "nearest"),
                align_corners=(True, None),
                prob=0.15,
            ),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5),
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    elif mode == "validation":
        other_transforms = [
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    else:
        other_transforms = [
            CastToTyped(keys=["image"], dtype=(np.float32)),
            EnsureTyped(keys=["image"]),
        ]

    all_transforms = load_transforms + sample_transforms + other_transforms
    return Compose(all_transforms)
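
A small detail in the test branch above: dtype=(np.float32) is just np.float32 wrapped in
redundant parentheses, since a one-element tuple needs a trailing comma. The call still works
because CastToTyped accepts either a single dtype applied to every key or a sequence with one
dtype per key. Both spellings side by side (a sketch, assuming the same imports as above):

CastToTyped(keys=["image"], dtype=np.float32)     # single dtype for every listed key
CastToTyped(keys=["image"], dtype=(np.float32,))  # genuine one-element tuple: one dtype per key
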
Example #11
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord(
                "image"
            ),  # test to support both Tensor and Numpy array when inverting
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
        inverter = Invertd(
            keys=["image", "label"],
            transform=transform,
            loader=loader,
            orig_keys="label",
            nearest_interp=True,
            postfix="inverted",
            to_tensor=[True, False],
            device="cpu",
            num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
        )

        # execute 1 epoch
        for d in loader:
            d = inverter(d)
            # this unit test only covers basic function, test_handler_transform_inverter covers more
            self.assertTupleEqual(d["image"].shape[1:], (1, 100, 100, 100))
            self.assertTupleEqual(d["label"].shape[1:], (1, 100, 100, 100))
            # check the nearest interpolation mode
            for i in d["image_inverted"]:
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape, (1, 100, 101, 107))
            for i in d["label_inverted"]:
                np.testing.assert_allclose(
                    i.astype(np.uint8).astype(np.float32),
                    i.astype(np.float32))
                self.assertTupleEqual(i.shape, (1, 100, 101, 107))

        set_determinism(seed=None)
Example #12
    def test_value_3d(
        self,
        keys,
        data,
        expected_convert_result,
        expected_zoom_result,
        expected_zoom_keepsize_result,
        expected_flip_result,
        expected_clip_result,
        expected_rotate_result,
    ):
        test_dtype = [torch.float32]
        for dtype in test_dtype:
            data = CastToTyped(keys=["image", "boxes"], dtype=dtype)(data)
            # test ConvertBoxToStandardModed
            transform_convert_mode = ConvertBoxModed(**keys)
            convert_result = transform_convert_mode(data)
            assert_allclose(convert_result["boxes"],
                            expected_convert_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            invert_transform_convert_mode = Invertd(
                keys=["boxes"],
                transform=transform_convert_mode,
                orig_keys=["boxes"])
            data_back = invert_transform_convert_mode(convert_result)
            if "boxes_transforms" in data_back:  # if the transform is tracked in dict:
                self.assertEqual(data_back["boxes_transforms"],
                                 [])  # it should be updated
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test ZoomBoxd
            transform_zoom = ZoomBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      zoom=[0.5, 3, 1.5],
                                      keep_size=False)
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            assert_allclose(zoom_result["boxes"],
                            expected_zoom_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_zoom = Invertd(keys=["image", "boxes"],
                                            transform=transform_zoom,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_zoom(zoom_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            transform_zoom = ZoomBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      zoom=[0.5, 3, 1.5],
                                      keep_size=True)
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            assert_allclose(zoom_result["boxes"],
                            expected_zoom_keepsize_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            # test RandZoomBoxd
            transform_zoom = RandZoomBoxd(
                image_keys="image",
                box_keys="boxes",
                box_ref_image_keys="image",
                prob=1.0,
                min_zoom=(0.3, ) * 3,
                max_zoom=(3.0, ) * 3,
                keep_size=False,
            )
            zoom_result = transform_zoom(data)
            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
            invert_transform_zoom = Invertd(keys=["image", "boxes"],
                                            transform=transform_zoom,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_zoom(zoom_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test AffineBoxToImageCoordinated, AffineBoxToWorldCoordinated
            transform_affine = AffineBoxToImageCoordinated(
                box_keys="boxes", box_ref_image_keys="image")
            if not isinstance(
                    data["image"], MetaTensor
            ):  # without a MetaTensor the meta dict is undefined, so this should raise
                with self.assertRaises(Exception) as context:
                    transform_affine(deepcopy(data))
                self.assertTrue(
                    "Please check whether it is the correct the image meta key."
                    in str(context.exception))

            data["image"] = MetaTensor(
                data["image"],
                meta={
                    "affine": torch.diag(1.0 / torch.Tensor([0.5, 3, 1.5, 1]))
                })
            affine_result = transform_affine(data)
            if "boxes_transforms" in affine_result:
                self.assertEqual(len(affine_result["boxes_transforms"]), 1)
            assert_allclose(affine_result["boxes"],
                            expected_zoom_result,
                            type_test=True,
                            device_test=True,
                            atol=0.01)
            invert_transform_affine = Invertd(keys=["boxes"],
                                              transform=transform_affine,
                                              orig_keys=["boxes"])
            data_back = invert_transform_affine(affine_result)
            if "boxes_transforms" in data_back:
                self.assertEqual(data_back["boxes_transforms"], [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)
            invert_transform_affine = AffineBoxToWorldCoordinated(
                box_keys="boxes", box_ref_image_keys="image")
            data_back = invert_transform_affine(affine_result)
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=0.01)

            # test FlipBoxd
            transform_flip = FlipBoxd(image_keys="image",
                                      box_keys="boxes",
                                      box_ref_image_keys="image",
                                      spatial_axis=[0, 1, 2])
            flip_result = transform_flip(data)
            if "boxes_transforms" in flip_result:
                self.assertEqual(len(flip_result["boxes_transforms"]), 1)
            assert_allclose(flip_result["boxes"],
                            expected_flip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_flip = Invertd(keys=["image", "boxes"],
                                            transform=transform_flip,
                                            orig_keys=["image", "boxes"])
            data_back = invert_transform_flip(flip_result)
            if "boxes_transforms" in data_back:
                self.assertEqual(data_back["boxes_transforms"], [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            # test RandFlipBoxd
            for spatial_axis in [(0, ), (1, ), (2, ), (0, 1), (1, 2)]:
                transform_flip = RandFlipBoxd(
                    image_keys="image",
                    box_keys="boxes",
                    box_ref_image_keys="image",
                    prob=1.0,
                    spatial_axis=spatial_axis,
                )
                flip_result = transform_flip(data)
                if "boxes_transforms" in flip_result:
                    self.assertEqual(len(flip_result["boxes_transforms"]), 1)
                invert_transform_flip = Invertd(keys=["image", "boxes"],
                                                transform=transform_flip,
                                                orig_keys=["image", "boxes"])
                data_back = invert_transform_flip(flip_result)
                if "boxes_transforms" in data_back:
                    self.assertEqual(data_back["boxes_transforms"], [])
                assert_allclose(data_back["boxes"],
                                data["boxes"],
                                type_test=False,
                                device_test=False,
                                atol=1e-3)
                assert_allclose(data_back["image"],
                                data["image"],
                                type_test=False,
                                device_test=False,
                                atol=1e-3)

            # test ClipBoxToImaged
            transform_clip = ClipBoxToImaged(box_keys="boxes",
                                             box_ref_image_keys="image",
                                             label_keys=["labels", "scores"],
                                             remove_empty=True)
            clip_result = transform_clip(data)
            assert_allclose(clip_result["boxes"],
                            expected_clip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            assert_allclose(clip_result["labels"],
                            data["labels"][1:],
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            assert_allclose(clip_result["scores"],
                            data["scores"][1:],
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            transform_clip = ClipBoxToImaged(
                box_keys="boxes",
                box_ref_image_keys="image",
                label_keys=[],
                remove_empty=True)  # corner case when label_keys is empty
            clip_result = transform_clip(data)
            assert_allclose(clip_result["boxes"],
                            expected_clip_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)

            # test RandCropBoxByPosNegLabeld
            transform_crop = RandCropBoxByPosNegLabeld(
                image_keys="image",
                box_keys="boxes",
                label_keys=["labels", "scores"],
                spatial_size=2,
                num_samples=3)
            crop_result = transform_crop(data)
            assert len(crop_result) == 3
            for ll in range(3):
                assert_allclose(
                    crop_result[ll]["boxes"].shape[0],
                    crop_result[ll]["labels"].shape[0],
                    type_test=True,
                    device_test=True,
                    atol=1e-3,
                )
                assert_allclose(
                    crop_result[ll]["boxes"].shape[0],
                    crop_result[ll]["scores"].shape[0],
                    type_test=True,
                    device_test=True,
                    atol=1e-3,
                )

            # test RotateBox90d
            transform_rotate = RotateBox90d(image_keys="image",
                                            box_keys="boxes",
                                            box_ref_image_keys="image",
                                            k=1,
                                            spatial_axes=[0, 1])
            rotate_result = transform_rotate(data)
            self.assertEqual(len(rotate_result["image"].applied_operations), 1)
            assert_allclose(rotate_result["boxes"],
                            expected_rotate_result,
                            type_test=True,
                            device_test=True,
                            atol=1e-3)
            invert_transform_rotate = Invertd(keys=["image", "boxes"],
                                              transform=transform_rotate,
                                              orig_keys=["image", "boxes"])
            data_back = invert_transform_rotate(rotate_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)

            transform_rotate = RandRotateBox90d(image_keys="image",
                                                box_keys="boxes",
                                                box_ref_image_keys="image",
                                                prob=1.0,
                                                max_k=3,
                                                spatial_axes=[0, 1])
            rotate_result = transform_rotate(data)
            self.assertEqual(len(rotate_result["image"].applied_operations), 1)
            invert_transform_rotate = Invertd(keys=["image", "boxes"],
                                              transform=transform_rotate,
                                              orig_keys=["image", "boxes"])
            data_back = invert_transform_rotate(rotate_result)
            self.assertEqual(data_back["image"].applied_operations, [])
            assert_allclose(data_back["boxes"],
                            data["boxes"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
            assert_allclose(data_back["image"],
                            data["image"],
                            type_test=False,
                            device_test=False,
                            atol=1e-3)
Example #13
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = (
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100))
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True,
                        dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test EnsureTensor for complicated dict data and invert it
            CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
            # test to support Tensor, Numpy array and dictionary when inverting
            EnsureTyped(keys=["image", "test_dict"]),
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label",
                       times=2,
                       names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image",
                       times=2,
                       names=["image_inverted", "image_inverted1"]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
        inverter = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted", "label_inverted", "test_dict"],
            transform=transform,
            orig_keys=["label", "label", "test_dict"],
            meta_keys=[
                PostFix.meta("image_inverted"),
                PostFix.meta("label_inverted"), None
            ],
            orig_meta_keys=[
                PostFix.meta("label"),
                PostFix.meta("label"), None
            ],
            nearest_interp=True,
            to_tensor=[True, False, False],
            device="cpu",
        )

        inverter_1 = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted1", "label_inverted1"],
            transform=transform,
            orig_keys=["image", "image"],
            meta_keys=[
                PostFix.meta("image_inverted1"),
                PostFix.meta("label_inverted1")
            ],
            orig_meta_keys=[PostFix.meta("image"),
                            PostFix.meta("image")],
            nearest_interp=[True, False],
            to_tensor=[True, True],
            device="cpu",
        )

        expected_keys = [
            "image",
            "image_inverted",
            "image_inverted1",
            PostFix.meta("image_inverted1"),
            PostFix.meta("image_inverted"),
            PostFix.meta("image"),
            "image_transforms",
            "label",
            "label_inverted",
            "label_inverted1",
            PostFix.meta("label_inverted1"),
            PostFix.meta("label_inverted"),
            PostFix.meta("label"),
            "label_transforms",
            "test_dict",
            "test_dict_transforms",
        ]
        # execute 1 epoch
        for d in loader:
            d = decollate_batch(d)
            for item in d:
                item = inverter(item)
                item = inverter_1(item)

                self.assertListEqual(sorted(item), expected_keys)
                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                i = item["image_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                i = item["label_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                # test inverted test_dict
                self.assertTrue(
                    isinstance(item["test_dict"]["affine"], np.ndarray))
                self.assertTrue(
                    isinstance(item["test_dict"]["filename_or_obj"], str))

                # check the case that different items use different interpolation mode to invert transforms
                d = item["image_inverted1"]
                # if the interpolation mode is nearest, accumulated diff should be smaller than 1
                self.assertLess(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 1.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

                d = item["label_inverted1"]
                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
                self.assertGreater(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 10000.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

        # check labels match
        reverted = item["label_inverted"].detach().cpu().numpy().astype(
            np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 25300: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1821: windows torch 1.10.0
        self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821),
                        f"diff.  {reverted.size - n_good}")

        set_determinism(seed=None)
Example #14
cfg.train_aug = Compose([  # opening line lost in extraction; the name is inferred from the cfg.val_aug block below
    RandSpatialCropd(
        keys=("input", "mask"),
        roi_size=(cfg.img_size[0], cfg.img_size[1]),
        random_size=False,
    ),
    RandScaleIntensityd(keys="input", factors=(-0.2, 0.2), prob=0.5),
    RandShiftIntensityd(keys="input", offsets=(-51, 51), prob=0.5),
    RandLambdad(keys="input", func=lambda x: 255 - x, prob=0.5),
    RandCoarseDropoutd(
        keys=("input", "mask"),
        holes=8,
        spatial_size=(1, 1),
        max_spatial_size=(102, 102),
        prob=0.5,
    ),
    CastToTyped(keys="input", dtype=np.float32),
    NormalizeIntensityd(keys="input", nonzero=False),
    Lambdad(keys="input", func=lambda x: x.clip(-20, 20)),
    EnsureTyped(keys=("input", "mask")),
])

cfg.val_aug = Compose([
    Resized(
        keys=("input", "mask"),
        spatial_size=1120,
        size_mode="longest",
        mode="bilinear",
        align_corners=False,
    ),
    SpatialPadd(keys=("input", "mask"), spatial_size=(1120, 1120)),
    CenterSpatialCropd(keys=("input", "mask"),
Example #15
def run(file, inputMount, outputMount):
    inputPath = os.path.join(inputMount, file.split('/')[-1])
    outputPath = os.path.join(outputMount, file.split('/')[-1])
    print("attempting to infer cell clusters")
    try:

        # open the image only to find its size
        with Image.open(inputPath) as input_img:
            input_bbox = input_img.getbbox()
            input_width = input_bbox[2] - input_bbox[0]
            input_height = input_bbox[3] - input_bbox[1]
            #print('input image size:',input_width,input_height)

        # instantiate the model
        # standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
        #print('checking for cuda device')
        device = torch.device('cuda:0')
        model = monai.networks.nets.UNet(dimensions=2,
                                         in_channels=3,
                                         out_channels=2,
                                         channels=(16, 32, 64, 128, 256),
                                         strides=(2, 2, 2, 2),
                                         num_res_units=2,
                                         norm=Norm.BATCH).to(device)

        print('attempting load of pretrained network from /tmp directory')
        # read in the pretrained model
        model.load_state_dict(
            torch.load('/tmp/unet_368x368_segment_model.pth'))
        #print('loading complete')
        model.eval()
        NETWORK_IMAGE_SIZE = 368

        # define xforms in MONAI to prepare imagery for PyTorch inferencing
        infer_transforms = Compose([
            LoadPNGd(keys=['image']),
            AsChannelFirstd(keys=['image']),
            Resized(keys=['image'],
                    spatial_size=(NETWORK_IMAGE_SIZE, NETWORK_IMAGE_SIZE),
                    mode='bilinear',
                    align_corners=False),
            CastToTyped(keys=['image'], dtype='float32'),
            ScaleIntensityd(keys=['image'], minv=0.0, maxv=1.0),
            ToTensord(keys=['image'])
        ])

        # create and load a Monai dataset composed of the single image, because the transforms
        # are performed automatically in MONAI. A spec for the image is needed by the Dataset definition
        infer_files = [{'image': inputPath}]
        infer_ds = monai.data.Dataset(infer_files, transform=infer_transforms)
        infer_loader = monai.data.DataLoader(infer_ds, batch_size=1)

        # get the transformed single image out of the dataset
        input_data = monai.utils.misc.first(infer_loader)

        # move the tensor to the GPU
        input_tensor = input_data['image'].to(device)

        # run the forward prediction
        predict_tensor = model(input_tensor)
        #print('inference complete. preparing output')

        # get the result back from the GPU and drop the first dimension
        infer_array = predict_tensor.detach().cpu().squeeze()

        # rearrange to channels-last and take an argmax over channels to get a single-channel mask
        # (torch.argmax needs a Tensor, so use .permute rather than np.transpose,
        # which would silently convert the tensor to a NumPy array)
        pred_array = torch.argmax(infer_array.permute(1, 2, 0), dim=2)

        # convert type back from torch to numpy
        prediction = pred_array.numpy()

        # write the file out in a viewable way, the output image is resized to match the
        # input image size for convenience, even though inferencing is always done at the
        # size of the pretrained network
        outimg = Image.fromarray(prediction.astype('uint8') * 255)
        resized = outimg.resize((input_width, input_height))
        print('saving segmentation to:', outputPath)
        resized.save(outputPath)

    except OSError:
        print("cannot create inference image for", file)
Example #16
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord(
                "image"
            ),  # test to support both Tensor and Numpy array when inverting
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            postfix="inverted1",
            to_tensor=[True, False],
            device="cpu",
            num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
        ).attach(engine)

        # test different nearest interpolation values
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="image",
            nearest_interp=[True, False],
            post_func=[lambda x: x + 10, lambda x: x],
            postfix="inverted2",
            num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
        ).attach(engine)

        engine.run(loader, max_epochs=1)
        set_determinism(seed=None)
        self.assertTupleEqual(engine.state.output["image"].shape,
                              (2, 1, 100, 100, 100))
        self.assertTupleEqual(engine.state.output["label"].shape,
                              (2, 1, 100, 100, 100))
        # check the nearest interpolation mode
        for i in engine.state.output["image_inverted1"]:
            torch.testing.assert_allclose(
                i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
        for i in engine.state.output["label_inverted1"]:
            np.testing.assert_allclose(
                i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

        # check labels match
        reverted = engine.state.output["label_inverted1"][-1].astype(np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = engine.state.output["label_meta_dict"][
            "filename_or_obj"][-1]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 25300: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1824: torch 1.5.1
        self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824),
                        "diff. in 3 possible values")

        # check the case that different items use different interpolation mode to invert transforms
        for i in engine.state.output["image_inverted2"]:
            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
            self.assertLess(
                torch.sum(
                    i.to(torch.float) -
                    i.to(torch.uint8).to(torch.float)).item(), 1.0)
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

        for i in engine.state.output["label_inverted2"]:
            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
            self.assertGreater(
                torch.sum(
                    i.to(torch.float) -
                    i.to(torch.uint8).to(torch.float)).item(), 10000.0)
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
Example #17
def main():
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        required=True,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not os.path.exists(args.output_root):
        os.makedirs(args.output_root, exist_ok=True)

    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iterations
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    split = len(json_data[args.json_key]) // num_folds
    list_train = (json_data[args.json_key][:split * fold] +
                  json_data[args.json_key][split * (fold + 1):])
    list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))]

    # training data
    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(args.root, list_train[_i]["image"])
        str_seg = os.path.join(args.root, list_train[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    train_files_w = train_files[:len(train_files) // 2]
    train_files_w = partition_dataset(data=train_files_w,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_w:", len(train_files_w))

    train_files_a = train_files[len(train_files) // 2:]
    train_files_a = partition_dataset(data=train_files_a,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(args.root, list_valid[_i]["image"])
        str_seg = os.path.join(args.root, list_valid[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    val_files = files
    val_files = partition_dataset(data=val_files,
                                  shuffle=False,
                                  num_partitions=world_size,
                                  even_divisible=False)[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
        CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        Lambdad(
            keys=["label4crop"],
            func=lambda x: np.concatenate(
                [
                    ndimage.binary_dilation((x == _k).astype(x.dtype), iterations=48).astype(x.dtype)
                    for _k in range(output_classes)
                ],
                axis=0,
            ),
            overwrite=True,
        ),
        EnsureTyped(keys=["image", "label"]),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        SpatialPadd(keys=["image", "label", "label4crop"],
                    spatial_size=patch_size,
                    mode=["reflect", "constant", "constant"]),
        RandCropByLabelClassesd(keys=["image", "label"],
                                label_key="label4crop",
                                num_classes=output_classes,
                                ratios=[1] * output_classes,
                                spatial_size=patch_size,
                                num_samples=num_patches_per_image),
        Lambdad(keys=["label4crop"], func=lambda x: 0),
        RandRotated(keys=["image", "label"],
                    range_x=0.3,
                    range_y=0.3,
                    range_z=0.3,
                    mode=["bilinear", "nearest"],
                    prob=0.2),
        RandZoomd(keys=["image", "label"],
                  min_zoom=0.8,
                  max_zoom=1.2,
                  mode=["trilinear", "nearest"],
                  align_corners=[True, None],
                  prob=0.16),
        RandGaussianSmoothd(keys=["image"],
                            sigma_x=(0.5, 1.15),
                            sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15),
                            prob=0.15),
        RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
        RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
        CastToTyped(keys=["image", "label"],
                    dtype=(torch.float32, torch.uint8)),
        ToTensord(keys=["image", "label"]),
    ])

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"])
    ])
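    # the validation pipeline shares the deterministic preprocessing above but applies
    # no augmentation or cropping: whole volumes are kept for sliding-window inference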

    train_ds_a = monai.data.CacheDataset(data=train_files_a,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files,
                                     transform=val_transforms,
                                     cache_rate=1.0,
                                     num_workers=2)

    # monai.data.Dataset can be used as an alternative when debugging or when RAM is limited.
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=1,
                                  shuffle=False)

    # DataLoader can be used as an alternative when ThreadDataLoader is less efficient.
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

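    # DiNTS: TopologySearch defines the differentiable search space whose architecture
    # parameters (log_alpha_a for topology, log_alpha_c for cell operations) are learned
    # alongside the network weights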
    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )

    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )

    model = model.to(device)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose(
        [EnsureType(),
         AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
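    # Dice + cross-entropy on the foreground classes only; batch=True pools the dice
    # statistics over the whole batch, and the smooth terms guard against division by zero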
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    # optimizer
    optimizer = torch.optim.SGD(model.weight_parameters(),
                                lr=learning_rate * world_size,
                                momentum=0.9,
                                weight_decay=0.00004)
    arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)
    arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)
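    # bi-level optimization: SGD updates the network weights on one data split, while the
    # two Adam optimizers update the architecture parameters on the other split; learning
    # rates are scaled linearly with world_size for multi-GPU training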

    print()

    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")

        model = DistributedDataParallel(model,
                                        device_ids=[device],
                                        find_unused_parameters=True)

    if args.checkpoint is not None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {0:s}".format(
            args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    idx_iter = 0
    metric_values = list()

    if dist.get_rank() == 0:
        writer = SummaryWriter(
            log_dir=os.path.join(args.output_root, "Events"))

        with open(os.path.join(args.output_root, "accuracy_history.csv"),
                  "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
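        # step decay: halve the base learning rate each time the post-warm-up progress
        # passes another entry of learning_rate_milestones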
        decay = 0.5**np.sum([
            (epoch - num_epochs_warmup) /
            (num_epochs - num_epochs_warmup) > learning_rate_milestones
        ])
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
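            # weight-update phase: only the network weights receive gradients; the
            # architecture parameters stay frozen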
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
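                    # binary case: flip the prediction channels and invert the labels so
                    # the complementary class fills the slot scored by DiceCELoss (the
                    # same trick is repeated in the non-amp and architecture branches)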
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]),
                                         1 - labels)
                    else:
                        loss = loss_func(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) +
                      f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)

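            # architecture updates start only after the warm-up epochs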
            if epoch < num_epochs_warmup:
                continue

            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search = sample_a["image"].to(device)
            labels_search = sample_a["label"].to(device)
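            # architecture-update phase: freeze the network weights and enable gradients
            # for the topology/cell parameters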
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # auxiliary losses (entropy, topology, RAM cost) get a linearly increasing weight
            entropy_alpha_c = torch.tensor(0.).to(device)
            entropy_alpha_a = torch.tensor(0.).to(device)
            ram_cost_full = torch.tensor(0.).to(device)
            ram_cost_usage = torch.tensor(0.).to(device)
            ram_cost_loss = torch.tensor(0.).to(device)
            topology_loss = torch.tensor(0.).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1) *
                                F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

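            # the RAM penalty pushes the expected memory footprint (usage / full ratio)
            # toward the target budget factor_ram_cost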
            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape,
                                                           full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost -
                                      ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

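            # weight of the auxiliary losses, ramping linearly from 0 to 1 over the
            # post-warm-up epochs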
            combination_weights = (epoch - num_epochs_warmup) / (
                num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                         1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c) +
                                                   ram_cost_loss + 0.001 * topology_loss)

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                     1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                # match the amp branch: all auxiliary terms share the ramped weight
                loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c) +
                                               ram_cost_loss + 0.001 * topology_loss)

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19]) +
                    f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}")
                writer.add_scalar("train_loss_arch", loss.item(),
                                  epoch_len * epoch + step)

        # synchronize all processes and reduce results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        dist.all_reduce(loss_torch_arch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2,
                                     dtype=torch.float,
                                     device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

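                    # patch-wise sliding-window inference over the full volume, blending
                    # overlapping patches with Gaussian weighting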
                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals),
                                                    axis=0)

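                    # even slots accumulate per-class dice sums, odd slots count the
                    # valid samples, so both survive the all_reduce across ranks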
                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        # mask out cases where this class is absent (dice is NaN)
                        val1 = 1.0 - torch.isnan(value[0, _c]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronize all processes and reduce results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print(
                            "evaluation metric - class {0:d}:".format(_c + 1),
                            metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                    (node_a_d, arch_code_a_d, arch_code_c_d,
                     arch_code_a_max_d) = dints_space.decode()
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(args.output_root,
                                     "search_code_" + str(idx_iter) + ".pth"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(
                        best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(args.output_root, "progress.yaml"),
                              "w") as out_file:
                        yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, avg_metric, best_metric,
                                best_metric_epoch))

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(
                            os.path.join(args.output_root,
                                         "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n"
                            .format(epoch + 1, avg_metric, loss_torch_epoch,
                                    lr, elapsed_time, idx_iter))

                dist.barrier()

            torch.cuda.empty_cache()

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()

    return
Example #18
0
    def test_type(self, input_param, input_data, expected_type):
        result = CastToTyped(**input_param)(input_data)
        for k, v in result.items():
            self.assertEqual(v.dtype, expected_type[k])
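
A minimal, self-contained sketch of the behaviour this test exercises (the sample data and target dtypes below are illustrative, not taken from the test suite):

import numpy as np
import torch
from monai.transforms import CastToTyped

data = {"image": np.zeros((1, 8, 8), dtype=np.float64),
        "label": torch.zeros(1, 8, 8, dtype=torch.int64)}
# one target dtype per key; numpy inputs stay numpy arrays, torch inputs stay tensors
out = CastToTyped(keys=["image", "label"], dtype=(np.float32, torch.uint8))(data)
assert out["image"].dtype == np.float32
assert out["label"].dtype == torch.uint8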
Example #19
0
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord(KEYS),
            CastToTyped(KEYS, dtype=torch.uint8),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num_workers = 0 for macOS or GPU transforms
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            num_workers=num_workers,
        ).attach(engine)

        engine.run(loader, max_epochs=1)
        set_determinism(seed=None)
        self.assertTupleEqual(engine.state.output["image"].shape,
                              (2, 1, 100, 100, 100))
        self.assertTupleEqual(engine.state.output["label"].shape,
                              (2, 1, 100, 100, 100))
        for i in engine.state.output["image_inverted"] + engine.state.output[
                "label_inverted"]:
            torch.testing.assert_allclose(
                i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
        # check labels match
        reverted = engine.state.output["label_inverted"][-1]
        reverted = reverted.detach().cpu().numpy()[0].astype(np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = engine.state.output["label_meta_dict"][
            "filename_or_obj"][-1]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        self.assertTrue((reverted.size - n_good) in (25300, 1812),
                        "diff must be one of the two expected values")