Example #1
 def test_exceptions(self):
     with self.assertRaises(ValueError):  # no meta
         EnsureChannelFirstd("img")({
             "img": np.zeros((1, 2, 3)),
             "img_meta_dict": None
         })
     with self.assertRaises(ValueError):  # no meta channel
         EnsureChannelFirstd("img")({
             "img": np.zeros((1, 2, 3)),
             "img_meta_dict": {
                 "original_channel_dim": None
             }
         })
     EnsureChannelFirstd("img", strict_check=False)({
         "img":
         np.zeros((1, 2, 3)),
         "img_meta_dict":
         None
     })
     EnsureChannelFirstd("img", strict_check=False)({
         "img":
         np.zeros((1, 2, 3)),
         "img_meta_dict": {
             "original_channel_dim": None
         }
     })
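For contrast with the failure paths above, a minimal sketch of the success path with the meta-dict API (assuming a MONAI version that carries metadata in "img_meta_dict"):

import numpy as np
from monai.transforms import EnsureChannelFirstd

data = {
    "img": np.zeros((2, 3, 4)),
    "img_meta_dict": {"original_channel_dim": -1},  # last axis is the channel
}
out = EnsureChannelFirstd("img")(data)
print(out["img"].shape)  # (4, 2, 3): the channel axis is moved to the front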
Example #2
 def test_exceptions(self):
     im = torch.zeros((1, 2, 3))
     with self.assertRaises(ValueError):  # no meta
         EnsureChannelFirstd("img")({"img": im})
     with self.assertRaises(ValueError):  # no meta channel
         EnsureChannelFirstd("img")({
             "img":
             MetaTensor(im, meta={"original_channel_dim": None})
         })
     EnsureChannelFirstd("img", strict_check=False)({"img": im})
     EnsureChannelFirstd("img", strict_check=False)({
         "img":
         MetaTensor(im, meta={"original_channel_dim": None})
     })
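The same success path with the MetaTensor API used above (a sketch, assuming MONAI >= 0.9, where MetaTensor carries its own metadata):

import torch
from monai.data import MetaTensor
from monai.transforms import EnsureChannelFirstd

im = MetaTensor(torch.zeros((2, 3, 4)), meta={"original_channel_dim": -1})
out = EnsureChannelFirstd("img")({"img": im})
print(out["img"].shape)  # (4, 2, 3): the channel axis is moved to the front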
Example #3
 def pre_transforms(self, data):
     return [
         LoadImaged(keys=["image", "label"]),
         EnsureChannelFirstd(keys=["image", "label"]),
         AddBackgroundScribblesFromROId(
             scribbles="label",
             scribbles_bg_label=self.scribbles_bg_label,
             scribbles_fg_label=self.scribbles_fg_label,
         ),
         # at the moment the optimisers are the bottleneck and take a long time,
         # therefore scale non-isotropically with a large spacing
         Spacingd(keys=["image", "label"], pixdim=self.pix_dim, mode=["bilinear", "nearest"]),
         Orientationd(keys=["image", "label"], axcodes="RAS"),
         ScaleIntensityRanged(
             keys="image",
             a_min=self.intensity_range[0],
             a_max=self.intensity_range[1],
             b_min=self.intensity_range[2],
             b_max=self.intensity_range[3],
             clip=self.intensity_range[4],
         ),
         MakeLikelihoodFromScribblesHistogramd(
             image="image",
             scribbles="label",
             post_proc_label="prob",
             scribbles_bg_label=self.scribbles_bg_label,
             scribbles_fg_label=self.scribbles_fg_label,
             normalise=True,
         ),
     ]
Example #4
    def pre_transforms(self, data=None):
        t = [
            LoadImaged(keys="image", reader="ITKReader"),
            EnsureChannelFirstd(keys="image"),
            Orientationd(keys="image", axcodes="RAS"),
            ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        ]
        if self.type == InferType.DEEPEDIT:
            t.extend(
                [
                    AddGuidanceFromPointsCustomd(ref_image="image", guidance="guidance", label_names=self.labels),
                    Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
                    ResizeGuidanceMultipleLabelCustomd(guidance="guidance", ref_image="image"),
                    AddGuidanceSignalCustomd(
                        keys="image", guidance="guidance", number_intensity_ch=self.number_intensity_ch
                    ),
                ]
            )
        else:
            t.extend(
                [
                    Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
                    DiscardAddGuidanced(
                        keys="image", label_names=self.labels, number_intensity_ch=self.number_intensity_ch
                    ),
                ]
            )

        t.append(EnsureTyped(keys="image", device=data.get("device") if data else None))
        return t
Example #5
 def test_correct(self):
     with tempfile.TemporaryDirectory() as temp_dir:
         transforms = Compose([
             LoadImaged(("im1", "im2")),
             EnsureChannelFirstd(("im1", "im2")),
             CopyItemsd(("im2", "im2_meta_dict"),
                        names=("im3", "im3_meta_dict")),
             ResampleToMatchd("im3", "im1_meta_dict"),
             Lambda(update_fname),
             SaveImaged("im3",
                        output_dir=temp_dir,
                        output_postfix="",
                        separate_folder=False),
         ])
         data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
         # check that output sizes match
         assert_allclose(data["im1"].shape, data["im3"].shape)
         # and that the meta data has been updated accordingly
         assert_allclose(data["im3"].shape[1:],
                         data["im3_meta_dict"]["spatial_shape"],
                         type_test=False)
         assert_allclose(data["im3_meta_dict"]["affine"],
                         data["im1_meta_dict"]["affine"])
         # check we're different from the original
         self.assertTrue(
             any(i != j
                 for i, j in zip(data["im3"].shape, data["im2"].shape)))
         self.assertTrue(
             any(i != j
                 for i, j in zip(data["im3_meta_dict"]["affine"].flatten(),
                                 data["im2_meta_dict"]["affine"].flatten())))
         # test the inverse
         data = Invertd("im3", transforms, "im3")(data)
         assert_allclose(data["im2"].shape, data["im3"].shape)
Example #6
 def test_correct(self):
     transforms = Compose([
         LoadImaged(("im1", "im2")),
         EnsureChannelFirstd(("im1", "im2")),
         CopyItemsd(("im2"), names=("im3")),
         ResampleToMatchd("im3", "im1"),
         Lambda(update_fname),
         SaveImaged("im3",
                    output_dir=self.tmpdir,
                    output_postfix="",
                    separate_folder=False,
                    resample=False),
     ])
     data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
     # check that output sizes match
     assert_allclose(data["im1"].shape, data["im3"].shape)
     # and that the meta data has been updated accordingly
     assert_allclose(data["im3"].affine, data["im1"].affine)
     # check we're different from the original
     self.assertTrue(
         any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
     self.assertTrue(
         any(i != j for i, j in zip(data["im3"].affine.flatten(),
                                    data["im2"].affine.flatten())))
     # test the inverse
     data = Invertd("im3", transforms)(data)
     assert_allclose(data["im2"].shape, data["im3"].shape)
Example #7
 def test_linear_consistent_dict(self, xform_cls, input_dict, atol):
     """xform cls testing itk consistency"""
     img = LoadImaged(keys, image_only=True, simple_keys=True)({
         keys[0]: FILE_PATH,
         keys[1]: FILE_PATH_1,
     })
     img = EnsureChannelFirstd(keys)(img)
     ref_1 = {k: _create_itk_obj(img[k][0], img[k].affine) for k in keys}
     output = self.run_transform(img, xform_cls, input_dict)
     ref_2 = {
         k: _create_itk_obj(output[k][0], output[k].affine)
         for k in keys
     }
     expected = {k: _resample_to_affine(ref_1[k], ref_2[k]) for k in keys}
     # compare ref_2 and expected results from itk
     diff = {
         k: np.abs(
             itk.GetArrayFromImage(ref_2[k]) -
             itk.GetArrayFromImage(expected[k]))
         for k in keys
     }
     avg_diff = {k: np.mean(diff[k]) for k in keys}
     for k in keys:
         self.assertTrue(avg_diff[k] < atol,
                         f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
Example #8
 def val_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), reader="ITKReader"),
         NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),
         EnsureChannelFirstd(keys=("image", "label")),
         Orientationd(keys=["image", "label"], axcodes="RAS"),
         # This transform may not work well for MR images
         ScaleIntensityRanged(keys=("image"),
                              a_min=-175,
                              a_max=250,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         Resized(keys=("image", "label"),
                 spatial_size=self.spatial_size,
                 mode=("area", "nearest")),
         # Transforms for click simulation
         FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
         AddInitialSeedPointMissingLabelsd(keys="label",
                                           guidance="guidance",
                                           sids="sids"),
         AddGuidanceSignalCustomd(
             keys="image",
             guidance="guidance",
             number_intensity_ch=self.number_intensity_ch),
         ToTensord(keys=("image", "label")),
         SelectItemsd(keys=("image", "label", "guidance", "label_names")),
     ]
Example #9
def run_inference_test(root_dir, device="cuda:0"):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = SaveImage(
            output_dir=os.path.join(root_dir, "output"),
            dtype=np.float32,
            output_ext=".nii.gz",
            output_postfix="seg",
            mode="bilinear",
        )
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for sliding window inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            val_meta = decollate_batch(val_data[PostFix.meta("img")])
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            for img, meta in zip(val_outputs, val_meta):  # save a decollated batch of files
                saver(img, meta)

    return dice_metric.aggregate().item()
Example #10
 def test_load_png(self):
     spatial_size = (256, 256, 3)
     test_image = np.random.randint(0, 256, size=spatial_size)
     with tempfile.TemporaryDirectory() as tempdir:
         filename = os.path.join(tempdir, "test_image.png")
         Image.fromarray(test_image.astype("uint8")).save(filename)
         result = LoadImaged(keys="img")({"img": filename})
         result = EnsureChannelFirstd(keys="img")(result)
         self.assertEqual(result["img"].shape[0], 3)
Example #11
def get_image_transforms():
    itk_reader = monai.data.ITKReader()
    # Define transforms for image
    image_transforms = Compose([
        LoadImaged(keys=['img'], reader=itk_reader),
        EnsureChannelFirstd(keys=['img']),
        ScaleIntensityd(keys=['img']),
        ToTensord(keys=['img']),
    ])
    return image_transforms
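A brief usage sketch for the function above (the input path is hypothetical):

transforms = get_image_transforms()
data = transforms({"img": "/path/to/volume.nii.gz"})  # hypothetical path
print(data["img"].shape)  # channel-first tensor, e.g. (1, D, H, W) for a 3D volume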
Example #12
 def pre_transforms(self, data=None) -> Sequence[Callable]:
     return [
         LoadImaged(keys="image", reader="ITKReader"),
         EnsureChannelFirstd(keys="image"),
         Spacingd(keys="image", pixdim=self.target_spacing),
         ScaleIntensityRanged(keys="image",
                              a_min=-175,
                              a_max=250,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         EnsureTyped(keys="image"),
     ]
Example #13
    def test_load_nifti(self, input_param, filenames, original_channel_dim):
        if original_channel_dim is None:
            test_image = np.random.rand(128, 128, 128)
        elif original_channel_dim == -1:
            test_image = np.random.rand(128, 128, 128, 1)

        with tempfile.TemporaryDirectory() as tempdir:
            for i, name in enumerate(filenames):
                filenames[i] = os.path.join(tempdir, name)
                nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
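            # a list of filenames loads as a stacked array; EnsureChannelFirstd then
            # moves the stack axis to the front, so shape[0] == len(filenames)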
            result = LoadImaged(**input_param)({"img": filenames})
            result = EnsureChannelFirstd(**input_param)(result)
            self.assertEqual(result["img"].shape[0], len(filenames))
Example #14
 def val_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), reader="ITKReader"),
         EnsureChannelFirstd(keys=("image", "label")),
         Spacingd(keys=("image", "label"),
                  pixdim=self.target_spacing,
                  mode=("bilinear", "nearest")),
         ScaleIntensityRanged(keys="image",
                              a_min=-175,
                              a_max=250,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         EnsureTyped(keys=("image", "label")),
         SelectItemsd(keys=("image", "label")),
     ]
Example #15
 def test_inverse(self):
     loader = Compose(
         [LoadImaged(("im1", "im2")),
          EnsureChannelFirstd(("im1", "im2"))])
     data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})
     tr = ResampleToMatch()
     im_mod = tr(data["im2"], data["im1"])
     self.assertNotEqual(im_mod.shape, data["im2"].shape)
     self.assertGreater(
         ((im_mod.affine - data["im2"].affine)**2).sum()**0.5, 1e-2)
     # inverse
     im_mod2 = tr.inverse(im_mod)
     self.assertEqual(im_mod2.shape, data["im2"].shape)
     self.assertLess(((im_mod2.affine - data["im2"].affine)**2).sum()**0.5,
                     1e-2)
     self.assertEqual(im_mod2.applied_operations, [])
Example #16
    def test_correct(self, reader, writer):
        loader = Compose([
            LoadImaged(("im1", "im2"), reader=reader),
            EnsureChannelFirstd(("im1", "im2"))
        ])
        data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})

        with self.assertRaises(ValueError):
            ResampleToMatch(mode=None)(img=data["im2"], img_dst=data["im1"])
        im_mod = ResampleToMatch()(data["im2"], data["im1"])
        saver = SaveImaged("im3",
                           output_dir=self.tmpdir,
                           output_postfix="",
                           separate_folder=False,
                           writer=writer,
                           resample=False)
        im_mod.meta["filename_or_obj"] = get_rand_fname()
        saver({"im3": im_mod})

        saved = nib.load(
            os.path.join(self.tmpdir, im_mod.meta["filename_or_obj"]))
        assert_allclose(data["im1"].shape[1:], saved.shape)
        assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19]))
Example #17
 def train_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), reader="ITKReader"),
         NormalizeLabelsInDatasetd(
             keys="label",
             label_names=self._labels),  # specifically to handle missing labels
         EnsureChannelFirstd(keys=("image", "label")),
         Spacingd(keys=("image", "label"),
                  pixdim=self.target_spacing,
                  mode=("bilinear", "nearest")),
         CropForegroundd(keys=("image", "label"), source_key="image"),
         SpatialPadd(keys=("image", "label"),
                     spatial_size=self.spatial_size),
         ScaleIntensityRanged(keys="image",
                              a_min=-175,
                              a_max=250,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         RandCropByPosNegLabeld(
             keys=("image", "label"),
             label_key="label",
             spatial_size=self.spatial_size,
             pos=1,
             neg=1,
             num_samples=self.num_samples,
             image_key="image",
             image_threshold=0,
         ),
         EnsureTyped(keys=("image", "label"), device=context.device),
         RandFlipd(keys=("image", "label"), spatial_axis=[0], prob=0.10),
         RandFlipd(keys=("image", "label"), spatial_axis=[1], prob=0.10),
         RandFlipd(keys=("image", "label"), spatial_axis=[2], prob=0.10),
         RandRotate90d(keys=("image", "label"), prob=0.10, max_k=3),
         RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
         SelectItemsd(keys=("image", "label")),
     ]
Example #18
def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num,
                        num_samples):
    if mode != "test":
        keys = ["image", "label"]
    else:
        keys = ["image"]

    load_transforms = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
    ]
    # 2. sampling
    sample_transforms = [
        PreprocessAnisotropic(
            keys=keys,
            clip_values=clip_values[task_id],
            pixdim=spacing[task_id],
            normalize_values=normalize_values[task_id],
            model_mode=mode,
        ),
    ]
    # 3. spatial transforms
    if mode == "train":
        other_transforms = [
            SpatialPadd(keys=["image", "label"],
                        spatial_size=patch_size[task_id]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=patch_size[task_id],
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("trilinear", "nearest"),
                align_corners=(True, None),
                prob=0.15,
            ),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5),
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    elif mode == "validation":
        other_transforms = [
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    else:
        other_transforms = [
            CastToTyped(keys=["image"], dtype=(np.float32)),
            EnsureTyped(keys=["image"]),
        ]

    all_transforms = load_transforms + sample_transforms + other_transforms
    return Compose(all_transforms)
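A hedged usage sketch for the factory above, assuming the module-level tables it references (clip_values, spacing, normalize_values, patch_size) are populated for the chosen task id; the argument values here are illustrative:

train_transform = get_task_transforms(
    mode="train", task_id="04", pos_sample_num=2, neg_sample_num=1, num_samples=4)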
Example #19
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = (
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100))
        transform = Compose([
            LoadImaged(KEYS, image_only=True),
            EnsureChannelFirstd(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, prob=0, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True,
                        dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label",
                       times=2,
                       names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image",
                       times=2,
                       names=["image_inverted", "image_inverted1"]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

        dataset = Dataset(data, transform=transform)
        transform.inverse(dataset[0])
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
        inverter = Invertd(
            # invert the copied keys using the transform trace recorded for `label`
            keys=["image_inverted", "label_inverted"],
            transform=transform,
            orig_keys=["label", "label"],
            nearest_interp=True,
            device="cpu",
        )

        inverter_1 = Invertd(
            # invert the copied keys using the transform trace recorded for `image`
            keys=["image_inverted1", "label_inverted1"],
            transform=transform,
            orig_keys=["image", "image"],
            nearest_interp=[True, False],
            device="cpu",
        )

        expected_keys = [
            "image", "image_inverted", "image_inverted1", "label",
            "label_inverted", "label_inverted1"
        ]
        # execute 1 epoch
        for d in loader:
            d = decollate_batch(d)
            for item in d:
                item = inverter(item)
                item = inverter_1(item)

                self.assertListEqual(sorted(item), expected_keys)
                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                i = item["image_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                i = item["label_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))

                # check the case that different items use different interpolation mode to invert transforms
                d = item["image_inverted1"]
                # if the interpolation mode is nearest, accumulated diff should be smaller than 1
                self.assertLess(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 1.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

                d = item["label_inverted1"]
                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
                self.assertGreater(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 10000.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

        # check labels match
        reverted = item["label_inverted"].detach().cpu().numpy().astype(
            np.int32)
        original = LoadImaged(KEYS, image_only=True)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = item["label_inverted"].meta["filename_or_obj"]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 25300: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1821: windows torch 1.10.0
        self.assertTrue((reverted.size - n_good) < 40000,
                        f"diff.  {reverted.size - n_good}")

        set_determinism(seed=None)
Example #20
    def configure(self):
        self.set_device()
        network = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(self.device)
        if self.multi_gpu:
            network = DistributedDataParallel(
                module=network,
                device_ids=[self.device],
                find_unused_parameters=False,
            )

        train_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            Spacingd(keys=("image", "label"),
                     pixdim=[1.0, 1.0, 1.0],
                     mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            RandCropByPosNegLabeld(
                keys=("image", "label"),
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=("image", "label")),
        ])
        train_datalist = load_decathlon_datalist(self.data_list_file_path,
                                                 True, "training")
        if self.multi_gpu:
            train_datalist = partition_dataset(
                data=train_datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )[dist.get_rank()]
        train_ds = CacheDataset(
            data=train_datalist,
            transform=train_transforms,
            cache_num=32,
            cache_rate=1.0,
            num_workers=4,
        )
        train_data_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
        )
        val_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            ToTensord(keys=("image", "label")),
        ])

        val_datalist = load_decathlon_datalist(self.data_list_file_path, True,
                                               "validation")
        val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
        val_data_loader = DataLoader(
            val_ds,
            batch_size=1,
            shuffle=False,
            num_workers=4,
        )
        post_transform = Compose([
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=["pred", "label"],
                argmax=[True, False],
                to_onehot=True,
                n_classes=2,
            ),
        ])
        # metric
        key_val_metric = {
            "val_mean_dice": MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=self.device,
            )
        }
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(
                save_dir=self.ckpt_dir,
                save_dict={"model": network},
                save_key_metric=True,
            ),
            TensorBoardStatsHandler(log_dir=self.ckpt_dir,
                                    output_transform=lambda x: None),
        ]
        self.eval_engine = SupervisedEvaluator(
            device=self.device,
            val_data_loader=val_data_loader,
            network=network,
            inferer=SlidingWindowInferer(
                roi_size=[160, 160, 160],
                sw_batch_size=4,
                overlap=0.5,
            ),
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            val_handlers=val_handlers,
            amp=self.amp,
        )

        optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
        loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=5000,
                                                       gamma=0.1)
        train_handlers = [
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            ValidationHandler(validator=self.eval_engine,
                              interval=self.val_interval,
                              epoch_level=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(
                log_dir=self.ckpt_dir,
                tag_name="train_loss",
                output_transform=lambda x: x["loss"],
            ),
        ]

        self.train_engine = SupervisedTrainer(
            device=self.device,
            max_epochs=self.max_epochs,
            train_data_loader=train_data_loader,
            network=network,
            optimizer=optimizer,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=post_transform,
            key_train_metric=None,
            train_handlers=train_handlers,
            amp=self.amp,
        )

        if self.local_rank > 0:
            self.train_engine.logger.setLevel(logging.WARNING)
            self.eval_engine.logger.setLevel(logging.WARNING)
Example #21
    def test_train_timing(self):
        images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
        train_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[:32], segs[:32])]
        val_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[-9:], segs[-9:])]

        device = torch.device("cuda:0")
        # define transforms for train and validation
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"],
                          prob=0.5,
                          spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"],
                      prob=0.5,
                      min_zoom=0.8,
                      max_zoom=1.2,
                      keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"],
                        prob=0.5,
                        rotate_range=np.pi / 2,
                        mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image",
                                   prob=0.5,
                                   factors=0.05,
                                   nonzero=True),
        ])

        val_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ])

        max_epochs = 5
        learning_rate = 2e-4
        val_interval = 1  # do validation for every epoch

        # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
        train_ds = CacheDataset(data=train_files,
                                transform=train_transforms,
                                cache_rate=1.0,
                                num_workers=8)
        val_ds = CacheDataset(data=val_files,
                              transform=val_transforms,
                              cache_rate=1.0,
                              num_workers=5)
        # disable multiple workers because `ThreadDataLoader` uses multiple threads
        train_loader = ThreadDataLoader(train_ds,
                                        num_workers=0,
                                        batch_size=4,
                                        shuffle=True)
        val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

        loss_function = DiceCELoss(to_onehot_y=True,
                                   softmax=True,
                                   squared_pred=True,
                                   batch=True)
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(device)

        # the Novograd paper suggests using a bigger LR than with Adam,
        # because Adam normalizes by element-wise second moments
        optimizer = Novograd(model.parameters(), learning_rate * 10)
        scaler = torch.cuda.amp.GradScaler()

        post_pred = Compose(
            [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
        post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

        dice_metric = DiceMetric(include_background=True,
                                 reduction="mean",
                                 get_not_nans=False)

        best_metric = -1
        total_start = time.time()
        for epoch in range(max_epochs):
            epoch_start = time.time()
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step_start = time.time()
                step += 1
                optimizer.zero_grad()
                # set AMP for training
                with torch.cuda.amp.autocast():
                    outputs = model(batch_data["image"])
                    loss = loss_function(outputs, batch_data["label"])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_loss += loss.item()
                epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                      f" step time: {(time.time() - step_start):.4f}")
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    for val_data in val_loader:
                        roi_size = (96, 96, 96)
                        sw_batch_size = 4
                        # set AMP for validation
                        with torch.cuda.amp.autocast():
                            val_outputs = sliding_window_inference(
                                val_data["image"], roi_size, sw_batch_size,
                                model)

                        val_outputs = [
                            post_pred(i) for i in decollate_batch(val_outputs)
                        ]
                        val_labels = [
                            post_label(i)
                            for i in decollate_batch(val_data["label"])
                        ]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    if metric > best_metric:
                        best_metric = metric
                    print(
                        f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                    )
            print(
                f"time consumed by epoch {epoch + 1}: {(time.time() - epoch_start):.4f}"
            )

        total_time = time.time() - total_start
        print(
            f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
        )
        # test expected metrics
        self.assertGreater(best_metric, 0.95)
Example #22
def run_training_test(root_dir,
                      device="cuda:0",
                      cachedataset=0,
                      readers=(None, None)):
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"], reader=readers[0]),
        EnsureChannelFirstd(keys=["img", "seg"]),
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        Spacingd(keys=["img", "seg"],
                 pixdim=[1.2, 0.8, 0.7],
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])
    train_transforms.set_random_state(1234)
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"], reader=readers[1]),
        EnsureChannelFirstd(keys=["img", "seg"]),
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        Spacingd(keys=["img", "seg"],
                 pixdim=[1.2, 0.8, 0.7],
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files,
                                           transform=train_transforms,
                                           cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files,
                                          transform=train_transforms,
                                          cache_dir=root_dir)
    else:
        train_ds = monai.data.Dataset(data=train_files,
                                      transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=2,
                                         shuffle=True,
                                         num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([
        ToTensor(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True)
    ])
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            with eval_mode(model):
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(
                        device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    # decollate prediction into a list and execute post processing for every item
                    val_outputs = [
                        val_post_tran(i) for i in decollate_batch(val_outputs)
                    ]
                    # compute metrics
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")
    print(
        f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}"
    )
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
Example #23
out_dir = "./outputs_fast"

train_images = sorted(glob.glob(os.path.join(data_root, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_root, "labelsTr", "*.nii.gz")))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

set_determinism(seed=0)

train_transforms = Compose(
    [
        Range("LoadImage")(LoadImaged(keys=["image", "label"])),
        Range()(EnsureChannelFirstd(keys=["image", "label"])),
        Range("Spacing")(
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            )
        ),
        Range()(Orientationd(keys=["image", "label"], axcodes="RAS")),
        Range()(
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
Example #24
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128,
                                     128,
                                     128,
                                     num_seg_classes=1,
                                     channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img",
                spatial_size=(96, 96, 96),
                mode="trilinear",
                align_corners=True),
        ScaleIntensityd(keys="img"),
        ToTensord(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        Invertd(keys="pred",
                transform=pre_transforms,
                loader=dataloader,
                orig_keys="img",
                nearest_interp=True),
        SaveImaged(keys="pred_inverted",
                   output_dir="./output",
                   output_postfix="seg",
                   resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    net.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth"))

    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for sliding window inference
            d["pred"] = sliding_window_inference(inputs=images,
                                                 roi_size=(96, 96, 96),
                                                 sw_batch_size=4,
                                                 predictor=net)
            # execute post transforms to invert spatial transforms and save to NIfTI files
            post_transforms(d)
Example #25
def compute(args):
    # generate synthetic data for the example
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random pred, label pairs for evaluation
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # if there are multiple nodes, set the random seed so every node generates the same data
        np.random.seed(seed=0)
        for i in range(16):
            pred, label = create_test_image_3d(128,
                                               128,
                                               128,
                                               num_seg_classes=1,
                                               channel_dim=-1,
                                               noise_max=0.5)
            n = nib.Nifti1Image(pred, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"pred{i:d}.nii.gz"))
            n = nib.Nifti1Image(label, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"label{i:d}.nii.gz"))

    # initialize the distributed evaluation process, change to NCCL backend if computing on GPU
    dist.init_process_group(backend="gloo", init_method="env://")

    preds = sorted(glob(os.path.join(args.dir, "pred*.nii.gz")))
    labels = sorted(glob(os.path.join(args.dir, "label*.nii.gz")))
    datalist = [{
        "pred": pred,
        "label": label
    } for pred, label in zip(preds, labels)]

    # split data for every subprocess, for example, 16 processes compute in parallel
    data_part = partition_dataset(
        data=datalist,
        num_partitions=dist.get_world_size(),
        shuffle=False,
        even_divisible=False,
    )[dist.get_rank()]

    # define transforms for predictions and labels
    transforms = Compose([
        LoadImaged(keys=["pred", "label"]),
        EnsureChannelFirstd(keys=["pred", "label"]),
        ScaleIntensityd(keys="pred"),
        EnsureTyped(keys=["pred", "label"]),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    data_part = [transforms(item) for item in data_part]

    # compute metrics for current process
    metric = DiceMetric(include_background=True,
                        reduction="mean",
                        get_not_nans=False)
    metric(y_pred=[i["pred"] for i in data_part],
           y=[i["label"] for i in data_part])
    filenames = [
        item["pred_meta_dict"]["filename_or_obj"] for item in data_part
    ]
    # all-gather results from all the processes and reduce for final result
    result = metric.aggregate().item()
    filenames = string_list_all_gather(strings=filenames)

    if args.local_rank == 0:
        print("mean dice: ", result)
        # generate metrics reports at: output/mean_dice_raw.csv, output/mean_dice_summary.csv, output/metrics.csv
        write_metrics_reports(
            save_dir="./output",
            images=filenames,
            metrics={"mean_dice": result},
            metric_details={"mean_dice": metric.get_buffer()},
            summary_ops="*",
        )

    metric.reset()

    dist.destroy_process_group()
Example #26
def main():
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        required=True,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not os.path.exists(args.output_root):
        os.makedirs(args.output_root, exist_ok=True)

    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iterations
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    split = len(json_data[args.json_key]) // num_folds
    list_train = json_data[args.json_key][:(split * fold)] + json_data[args.json_key][(split * (fold + 1)):]
    list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))]

    # training data
    files = []
    for item in list_train:
        str_img = os.path.join(args.root, item["image"])
        str_seg = os.path.join(args.root, item["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    train_files_w = train_files[:len(train_files) // 2]
    train_files_w = partition_dataset(data=train_files_w,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_w:", len(train_files_w))

    train_files_a = train_files[len(train_files) // 2:]
    train_files_a = partition_dataset(data=train_files_a,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    files = []
    for item in list_valid:
        str_img = os.path.join(args.root, item["image"])
        str_seg = os.path.join(args.root, item["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    val_files = files
    val_files = partition_dataset(data=val_files,
                                  shuffle=False,
                                  num_partitions=world_size,
                                  even_divisible=False)[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
        CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        # build a per-class sampling mask: dilate each class's binary mask so that
        # random crops centred on it still include surrounding context
        Lambdad(
            keys=["label4crop"],
            func=lambda x: np.concatenate(
                tuple(
                    ndimage.binary_dilation((x == _k).astype(x.dtype), iterations=48).astype(x.dtype)
                    for _k in range(output_classes)
                ),
                axis=0,
            ),
            overwrite=True,
        ),
        EnsureTyped(keys=["image", "label"]),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        SpatialPadd(keys=["image", "label", "label4crop"],
                    spatial_size=patch_size,
                    mode=["reflect", "constant", "constant"]),
        RandCropByLabelClassesd(keys=["image", "label"],
                                label_key="label4crop",
                                num_classes=output_classes,
                                ratios=[
                                    1,
                                ] * output_classes,
                                spatial_size=patch_size,
                                num_samples=num_patches_per_image),
        Lambdad(keys=["label4crop"], func=lambda x: 0),
        RandRotated(keys=["image", "label"],
                    range_x=0.3,
                    range_y=0.3,
                    range_z=0.3,
                    mode=["bilinear", "nearest"],
                    prob=0.2),
        RandZoomd(keys=["image", "label"],
                  min_zoom=0.8,
                  max_zoom=1.2,
                  mode=["trilinear", "nearest"],
                  align_corners=[True, None],
                  prob=0.16),
        RandGaussianSmoothd(keys=["image"],
                            sigma_x=(0.5, 1.15),
                            sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15),
                            prob=0.15),
        RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
        RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
        CastToTyped(keys=["image", "label"],
                    dtype=(torch.float32, torch.uint8)),
        ToTensord(keys=["image", "label"]),
    ])

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"])
    ])

    train_ds_a = monai.data.CacheDataset(data=train_files_a,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files,
                                     transform=val_transforms,
                                     cache_rate=1.0,
                                     num_workers=2)

    # monai.data.Dataset can be used as an alternative when debugging or when RAM is limited.
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=1,
                                  shuffle=False)

    # DataLoader can be used as an alternative when ThreadDataLoader is less efficient.
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )

    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )

    model = model.to(device)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose(
        [EnsureType(),
         AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    # optimizer
    optimizer = torch.optim.SGD(model.weight_parameters(),
                                lr=learning_rate * world_size,
                                momentum=0.9,
                                weight_decay=0.00004)
    arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)
    arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)
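
    # DiNTS search alternates two optimisation problems: the SGD `optimizer`
    # updates the network weights on one half of the data (train_files_w), while
    # the two Adam optimisers update the architecture parameters
    # (log_alpha_a / log_alpha_c) on the other half (train_files_a) once the
    # warm-up epochs are over, as seen in the training loop below.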

    print()

    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")

        model = DistributedDataParallel(model,
                                        device_ids=[device],
                                        find_unused_parameters=True)

    if args.checkpoint is not None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {:s}".format(args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    idx_iter = 0
    metric_values = list()

    if dist.get_rank() == 0:
        writer = SummaryWriter(
            log_dir=os.path.join(args.output_root, "Events"))

        with open(os.path.join(args.output_root, "accuracy_history.csv"),
                  "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        # step-decay schedule: halve the learning rate each time the progress
        # through the post-warmup epochs passes a milestone (40% and 80%)
        decay = 0.5 ** np.sum([
            (epoch - num_epochs_warmup) /
            (num_epochs - num_epochs_warmup) > learning_rate_milestones
        ])
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(
                device), batch_data["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]),
                                         1 - labels)
                    else:
                        loss = loss_func(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) +
                      f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)

            if epoch < num_epochs_warmup:
                continue

            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(
                device), sample_a["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # linear increase topology and RAM loss
            entropy_alpha_c = torch.tensor(0.).to(device)
            entropy_alpha_a = torch.tensor(0.).to(device)
            ram_cost_full = torch.tensor(0.).to(device)
            ram_cost_usage = torch.tensor(0.).to(device)
            ram_cost_loss = torch.tensor(0.).to(device)
            topology_loss = torch.tensor(0.).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1)
                                * F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape,
                                                           full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost -
                                      ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            combination_weights = (epoch - num_epochs_warmup) / (
                num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                         1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c)
                                                   + ram_cost_loss + 0.001 * topology_loss)

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                     1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                # NOTE: unlike the amp branch above, combination_weights here
                # scales only the entropy terms
                loss += 1.0 * (combination_weights * (entropy_alpha_a + entropy_alpha_c)
                               + ram_cost_loss + 0.001 * topology_loss)

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19]) +
                    f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}")
                writer.add_scalar("train_loss_arch", loss.item(),
                                  epoch_len * epoch + step)

        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        dist.all_reduce(loss_torch_arch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2,
                                     dtype=torch.float,
                                     device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals),
                                                    axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        val1 = 1.0 - torch.isnan(value[0, _c]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronizes all processes and reduce results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print(
                            "evaluation metric - class {0:d}:".format(_c + 1),
                            metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                    node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode()
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(args.output_root,
                                     "search_code_" + str(idx_iter) + ".pth"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(
                        best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(args.output_root, "progress.yaml"),
                              "w") as out_file:
                        documents = yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, avg_metric, best_metric,
                                best_metric_epoch))

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(
                            os.path.join(args.output_root,
                                         "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n"
                            .format(epoch + 1, avg_metric, loss_torch_epoch,
                                    lr, elapsed_time, idx_iter))

                dist.barrier()

            torch.cuda.empty_cache()

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()

    return
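
The learning-rate schedule in the loop above is a step decay: the rate is halved each time the post-warmup progress crosses a milestone (40% and 80%). A minimal standalone sketch of the same computation, with illustrative epoch values:

import numpy as np

learning_rate = 0.025
num_epochs, num_epochs_warmup = 1430, 715
milestones = np.array([0.4, 0.8])

for epoch in (715, 1100, 1420):
    progress = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)
    decay = 0.5 ** np.sum(progress > milestones)
    print(epoch, learning_rate * decay)  # decay is 1.0, 0.5, 0.25 respectively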
Example No. 27
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device=device),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[224, 224, 144],
                         random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster than DataLoader when all data is cached in memory and no per-iteration IO is needed
    train_loader = ThreadDataLoader(train_ds,
                                    num_workers=0,
                                    batch_size=args.batch_size,
                                    shuffle=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device=device),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster than DataLoader when all data is cached in memory and no per-iteration IO is needed
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=args.batch_size,
                                  shuffle=False)

    # create network, loss function and optimizer
    if args.network == "SegResNet":
        model = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        ).to(device)
    else:
        model = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
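
    # NOTE: the labels were converted into three overlapping channels (roughly
    # tumour core / whole tumour / enhancing tumour) by
    # ConvertToMultiChannelBasedOnBratsClassesd, so the task is multi-label:
    # per-channel sigmoid activation rather than softmax/one-hot.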
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True,
                                   reduction="mean_batch")

    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer,
                           lr_scheduler, scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, dice_metric, dice_metric_batch, post_trans)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                if dist.get_rank() == 0:
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

        print(
            f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
        )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}")
    dist.destroy_process_group()
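
main_worker calls train(...) and evaluate(...) helpers that are not part of this excerpt. A minimal sketch of what they could look like, matching the call signatures and return values used above; the roi_size and the use of sliding_window_inference / decollate_batch are assumptions for illustration, not the original implementation:

def train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler):
    # one epoch of mixed-precision training; returns the average loss
    model.train()
    epoch_loss, step = 0.0, 0
    for batch_data in train_loader:  # data is already on GPU via ToDeviced
        step += 1
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = loss_function(model(batch_data["image"]), batch_data["label"])
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    lr_scheduler.step()
    return epoch_loss / step


def evaluate(model, val_loader, dice_metric, dice_metric_batch, post_trans):
    # sliding-window inference plus Dice; returns mean and per-channel (tc/wt/et) scores
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_outputs = sliding_window_inference(
                inputs=val_data["image"], roi_size=(240, 240, 160),
                sw_batch_size=1, predictor=model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_data["label"])
            dice_metric_batch(y_pred=val_outputs, y=val_data["label"])
        metric = dice_metric.aggregate().item()
        metric_tc, metric_wt, metric_et = dice_metric_batch.aggregate()
        dice_metric.reset()
        dice_metric_batch.reset()
    return metric, metric_tc.item(), metric_wt.item(), metric_et.item()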
Example No. 28
def main():

    # Define file paths & output directory path
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    if not os.path.exists(logdir_path):
        os.makedirs(logdir_path, exist_ok=True)

    # Load Json & Append Root Path
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for each_d in train_Data:
        each_d['image'] = os.path.join(data_Root, each_d['image'])

    for each_d in val_Data:
        each_d['image'] = os.path.join(data_Root, each_d['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # Set Determinism
    set_determinism(seed=123)

    # Define Training Transforms
    train_Transforms = Compose(
        [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Spacingd(keys=["image"], pixdim=(
            2.0, 2.0, 2.0), mode=("bilinear")),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
        RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
        CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
        OneOf(transforms=[
            RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                               max_spatial_size=32),
            RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                               max_spatial_size=64),
            ]
        ),
        RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
        # NOTE: if image and image_2 were corrupted in the same random transform call, determinism would
        # give both views the exact same augmentation, which is not wanted here; hence two separate OneOf calls
        OneOf(transforms=[
            RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                               max_spatial_size=32),
            RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                               max_spatial_size=64),
        ]
        ),
        RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8)
        ]
    )

    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = (check_data["image"][0][0])
    print(f"image shape: {image.shape}")

    # Define Network ViT backbone & Loss & Optimizer
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
        in_channels=1,
        img_size=(96, 96, 96),
        patch_size=(16, 16, 16),
        pos_embed='conv',
        hidden_size=768,
        mlp_dim=3072,
    )

    model = model.to(device)

    # Define hyperparameters for the training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size*2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
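
    # The transform chain above yields two independently corrupted views of each
    # crop ("image" and "image_2") plus a clean copy ("gt_image"); in the training
    # loop below the L1 loss reconstructs gt_image while the contrastive loss
    # pulls the flattened outputs of the two views together.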

    # Define DataLoaders using MONAI; CacheDataset could be used instead to speed up loading
    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Adjust the CL loss by Recon Loss
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # CL & Recon Loss Storage of Value
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time-start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch+1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            with torch.no_grad():  # gradients are not needed for validation
                for val_batch in val_loader:
                    val_step += 1
                    start_time = time.time()
                    inputs, gt_input = (
                        val_batch["image"].to(device),
                        val_batch["gt_image"].to(device),
                    )
                    print('Input shape: {}'.format(inputs.shape))
                    outputs, _ = model(inputs)
                    val_loss = recon_loss(outputs, gt_input)
                    total_val_loss += val_loss.item()
                    end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {'epoch': max_epochs,
                              'state_dict': model.state_dict(),
                              'optimizer': optimizer.state_dict()
                              }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
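
A hypothetical way to resume from the checkpoint dictionary saved above (the key names are taken from the torch.save call; the path and device are illustrative):

checkpoint = torch.load(os.path.join(logdir_path, 'best_model.pt'), map_location=device)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])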
Example No. 29
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128,
                                     128,
                                     128,
                                     num_seg_classes=1,
                                     channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img",
                spatial_size=(96, 96, 96),
                mode="trilinear",
                align_corners=True),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        Invertd(
            keys="pred",  # invert the `pred` data field (multiple fields are also supported)
            transform=pre_transforms,
            orig_keys="img",  # use the pre_transforms information recorded for `img` to invert `pred`;
            # the same info can serve multiple fields, and different orig_keys per field are supported
            meta_keys="pred_meta_dict",  # key field to save the inverted meta data, every item maps to `keys`
            orig_meta_keys="img_meta_dict",  # get the meta data from `img_meta_dict` when inverting,
            # e.g. the `affine` may be needed to invert the `Spacingd` transform;
            # multiple fields can use the same meta data to invert
            meta_key_postfix="meta_dict",  # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key;
            # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}";
            # otherwise this arg is not needed during inverting
            nearest_interp=False,  # don't switch the interpolation mode to "nearest" when inverting,
            # to keep a smooth output before the `AsDiscreted` transform
            to_tensor=True,  # convert to PyTorch Tensor after inverting
        ),
        AsDiscreted(keys="pred", threshold=0.5),
        SaveImaged(keys="pred",
                   meta_keys="pred_meta_dict",
                   output_dir="./out",
                   output_postfix="seg",
                   resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    net.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth"))

    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(inputs=images,
                                                 roi_size=(96, 96, 96),
                                                 sw_batch_size=4,
                                                 predictor=net)
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            d = [post_transforms(i) for i in decollate_batch(d)]
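
With the SaveImaged arguments above, the binarised predictions are written under ./out with a "seg" postfix (by default MONAI places each output in its own subfolder). A quick check of the results, assuming that default layout:

    output_files = sorted(glob(os.path.join("out", "*", "*_seg.nii.gz")))
    print(f"saved {len(output_files)} segmentations:", output_files)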