Example #1
def get_data(keys):
    """Get the example data to be used.

    Use MarsAtlas as it only contains 1 image for quick download and
    that image is parcellated.
    """
    cache_dir = os.environ.get("MONAI_DATA_DIRECTORY") or tempfile.mkdtemp()
    fname = "MarsAtlas-MNI-Colin27.zip"
    url = "https://www.dropbox.com/s/ndz8qtqblkciole/" + fname + "?dl=1"
    out_path = os.path.join(cache_dir, "MarsAtlas-MNI-Colin27")
    zip_path = os.path.join(cache_dir, fname)

    download_and_extract(url, zip_path, out_path)

    image, label = sorted(glob(os.path.join(out_path, "*.nii")))

    data = {CommonKeys.IMAGE: image, CommonKeys.LABEL: label}

    transforms = Compose([
        LoadImaged(keys),
        AddChanneld(keys),
        ScaleIntensityd(CommonKeys.IMAGE),
        Rotate90d(keys, spatial_axes=[0, 2])
    ])
    data = transforms(data)
    max_size = max(data[keys[0]].shape)
    padder = SpatialPadd(keys, (max_size, max_size, max_size))
    return padder(data)
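A hypothetical driver for get_data (a sketch; it assumes MONAI is installed and the names imported for the function above, such as CommonKeys from monai.utils, are in scope):

# illustrative usage: request both the image and its parcellation label
keys = [CommonKeys.IMAGE, CommonKeys.LABEL]
data = get_data(keys)
print(data[CommonKeys.IMAGE].shape)  # cubic spatial shape after SpatialPadd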
Example #2
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        Spacingd(keys=["img", "seg"],
                 pixdim=[1.2, 0.8, 0.7],
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=np.float32)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define ROI size and batch size for sliding window inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            not_nans = dice_metric.not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = metric_sum / metric_count
    return metric
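The running average above weights each batch's Dice value by its not-NaN channel count. A minimal, self-contained restatement of that bookkeeping with illustrative numbers:

# weighted running mean over batches, as in the loop above
batches = [(0.82, 2.0), (0.91, 1.0), (0.76, 2.0)]  # (mean dice, not-NaN count)
metric_sum = sum(value * n for value, n in batches)
metric_count = sum(n for _, n in batches)
print(metric_sum / metric_count)  # overall mean over all valid channels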
Example #3
    def test_samples(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
        keys = "image"
        xforms = Compose(
            [
                LoadImaged(keys=keys),
                AddChanneld(keys=keys),
                ScaleIntensityd(keys=keys),
                RandSpatialCropSamplesd(keys=keys, roi_size=(8, 8, 5), random_size=True, num_samples=10),
            ]
        )
        image_path = os.path.join(testing_dir, "anatomical.nii")
        xforms.set_random_state(0)
        ims = xforms({keys: image_path})
        fig, mat = matshow3d(
            [im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False
        )
        self.assertTrue(mat.dtype == np.float32)

        with tempfile.TemporaryDirectory() as tempdir:
            tempimg = f"{tempdir}/matshow3d_patch_test.png"
            fig.savefig(tempimg)
            comp = compare_images(f"{testing_dir}/matshow3d_patch_test.png", tempimg, 5e-2, in_decorator=True)
            if comp:
                print("not none comp: ", comp)  # matplotlib 3.2.2
                np.testing.assert_allclose(comp["rms"], 30.786983, atol=1e-3, rtol=1e-3)
            else:
                self.assertIsNone(comp, f"value of comp={comp}")  # None indicates test passed
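For reference, a standalone matshow3d call on a synthetic volume, restricted to parameters exercised by this test (a sketch, not part of the test suite):

import numpy as np
from monai.visualize import matshow3d

vol = np.random.rand(8, 8, 5).astype(np.float32)  # a small 3D volume
fig, mat = matshow3d(vol, title="synthetic", figsize=(2, 2), frames_per_row=5, show=False)
print(mat.shape, mat.dtype)  # the montage array returned with the figure; dtype defaults to float32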
Example #4
 def test_factor_scale(self):
     key = "img"
     for p in TEST_NDARRAYS:
         scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1)
         result = scaler({key: p(self.imt)})
         expected = (self.imt * (1 + 0.1)).astype(np.float32)
         assert_allclose(result[key], p(expected), type_test="tensor")
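The semantics under test: with minv and maxv set to None, ScaleIntensityd multiplies the input by (1 + factor). A plain-numpy restatement of the expected value:

import numpy as np

img = np.array([0.0, 0.5, 1.0], dtype=np.float32)
expected = (img * (1 + 0.1)).astype(np.float32)  # same formula as the test above
print(expected)  # [0.   0.55 1.1 ]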
Example #5
    def _get_loader(self, folders):
        images = []
        segs = []
        for folder in folders:
            images += glob(os.path.join(folder, "*_im.nii.gz"))
            segs += glob(os.path.join(folder, "*_seg.nii.gz"))
        images = sorted(images, key=os.path.basename)
        segs = sorted(segs, key=os.path.basename)

        files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        transforms = Compose([
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ])

        ds = CacheDataset(data=files, transform=transforms)
        loader = DataLoader(ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)

        return loader
Example #6
    def test_3d_rgb(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
        keys = "image"
        xforms = Compose(
            [
                LoadImaged(keys=keys),
                AddChanneld(keys=keys),
                ScaleIntensityd(keys=keys),
                # change to RGB color image
                RepeatChanneld(keys=keys, repeats=3),
            ]
        )
        image_path = os.path.join(testing_dir, "anatomical.nii")
        ims = xforms({keys: image_path})

        fig = pyplot.figure()  # external figure
        fig, _ = matshow3d(
            volume=ims[keys],
            fig=fig,
            figsize=(2, 2),
            frames_per_row=5,
            every_n=2,
            frame_dim=-1,
            channel_dim=0,
            show=False,
        )

        with tempfile.TemporaryDirectory() as tempdir:
            tempimg = f"{tempdir}/matshow3d_rgb_test.png"
            fig.savefig(tempimg)
            comp = compare_images(f"{testing_dir}/matshow3d_rgb_test.png", tempimg, 5e-2)
            self.assertIsNone(comp, f"value of comp={comp}")  # None indicates test passed
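The grayscale-to-RGB step relies on repeating the channel axis; a minimal sketch with the array-level RepeatChannel transform (the dictionary version RepeatChanneld used above wraps the same logic):

import numpy as np
from monai.transforms import RepeatChannel

gray = np.random.rand(1, 4, 4).astype(np.float32)  # single-channel image
rgb = RepeatChannel(repeats=3)(gray)               # channel repeated to (3, 4, 4)
print(rgb.shape)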
Example #7
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=int)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define ROI size and batch size for sliding window inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = compute_meandice(y_pred=val_outputs,
                                     y=val_labels,
                                     include_background=True,
                                     to_onehot_y=False,
                                     add_sigmoid=True)
            metric_count += len(value)
            metric_sum += value.sum().item()
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(
                val_outputs, {
                    "filename_or_obj": val_data["img.filename_or_obj"],
                    "affine": val_data["img.affine"]
                })
        metric = metric_sum / metric_count
    return metric
Example #8
def run_inference_test(root_dir, device="cuda:0"):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = SaveImage(
            output_dir=os.path.join(root_dir, "output"),
            dtype=np.float32,
            output_ext=".nii.gz",
            output_postfix="seg",
            mode="bilinear",
        )
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define ROI size and batch size for sliding window inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            val_meta = decollate_batch(val_data[PostFix.meta("img")])
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            for img, meta in zip(val_outputs, val_meta):  # save a decollated batch of files
                saver(img, meta)

    return dice_metric.aggregate().item()
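The decollate/post-transform pattern used above, reduced to a self-contained sketch on toy tensors (same transform chain minus ToTensor):

import torch
from monai.data import decollate_batch
from monai.transforms import Activations, AsDiscrete, Compose

post = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
logits = torch.randn(2, 1, 4, 4)                    # a batch of raw network outputs
preds = [post(p) for p in decollate_batch(logits)]  # per-item post-processing
print(len(preds), preds[0].shape)                   # 2 items, each (1, 4, 4)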
Example #9
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray(im.astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth"))

    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = PNGSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define ROI size and batch size for sliding window inference
            roi_size = (96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            val_outputs = val_outputs.sigmoid() >= 0.5
            saver.save_batch(val_outputs)
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
Example #10
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
    ]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    val_files = [{"img": img, "label": label} for img, label in zip(images, labels)]

    # Define transforms for image
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            ToTensord(keys=["img"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121
    device = torch.device("cuda:0")
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)

    model.load_state_dict(torch.load("best_metric_model.pth"))
    model.eval()
    with torch.no_grad():
        num_correct = 0.0
        metric_count = 0
        saver = CSVSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["label"].to(device)
            val_outputs = model(val_images).argmax(dim=1)
            value = torch.eq(val_outputs, val_labels)
            metric_count += len(value)
            num_correct += value.sum().item()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = num_correct / metric_count
        print("evaluation metric:", metric)
        saver.finalize()
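The accuracy bookkeeping above, restated on toy tensors:

import torch

preds = torch.tensor([0, 1, 1, 0])   # argmax outputs
labels = torch.tensor([0, 1, 0, 0])
value = torch.eq(preds, labels)
print(value.sum().item() / len(value))  # 3 correct out of 4 -> 0.75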
Example #11
    def test_values(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        transform = Compose([
            LoadImaged(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(
                len(dataset),
                int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac))
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue(PostFix.meta("image") in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))

        with skip_if_downloading_fails():
            data = MedNISTDataset(root_dir=testing_dir,
                                  transform=transform,
                                  section="test",
                                  download=True,
                                  copy_cache=False)

        _test_dataset(data)

        # testing loading from the already-downloaded data
        data = MedNISTDataset(root_dir=Path(testing_dir),
                              transform=transform,
                              section="test",
                              download=False)
        self.assertEqual(data.get_num_classes(), 6)
        _test_dataset(data)
        data = MedNISTDataset(root_dir=testing_dir,
                              section="test",
                              download=False)
        self.assertTupleEqual(data[0]["image"].shape, (64, 64))
        # test same dataset length with different random seed
        data = MedNISTDataset(root_dir=testing_dir,
                              transform=transform,
                              section="test",
                              download=False,
                              seed=42)
        _test_dataset(data)
        self.assertEqual(data[0]["class_name"], "AbdomenCT")
        self.assertEqual(data[0]["label"].cpu().item(), 0)
        shutil.rmtree(os.path.join(testing_dir, "MedNIST"))
        try:
            MedNISTDataset(root_dir=testing_dir,
                           transform=transform,
                           section="test",
                           download=False)
        except RuntimeError as e:
            print(str(e))
            self.assertTrue(str(e).startswith("Cannot find dataset directory"))
Example #12
 def test_range_scale(self):
     key = 'img'
     scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)
     result = scaler({key: self.imt})
     mina = np.min(self.imt)
     maxa = np.max(self.imt)
     norm = (self.imt - mina) / (maxa - mina)
     expected = (norm * (2.0 - 1.0)) + 1.0
     np.testing.assert_allclose(result[key], expected)
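The min-max mapping this test verifies, as a worked numeric example:

import numpy as np

img = np.array([2.0, 4.0, 6.0])
norm = (img - img.min()) / (img.max() - img.min())  # -> [0. , 0.5, 1. ]
print(norm * (2.0 - 1.0) + 1.0)                     # -> [1. , 1.5, 2. ]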
Example #13
    def test_values(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        transform = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 52)
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue("image_meta_dict" in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 33, 47, 34))

        try:  # will start downloading if testing_dir doesn't have the Decathlon files
            data = DecathlonDataset(
                root_dir=testing_dir,
                task="Task04_Hippocampus",
                transform=transform,
                section="validation",
                download=True,
            )
        except (ContentTooShortError, HTTPError, RuntimeError) as e:
            print(str(e))
            if isinstance(e, RuntimeError):
                # FIXME: skip MD5 check as current downloading method may fail
                self.assertTrue(str(e).startswith("MD5 check"))
            return  # skipping this test due to network connection errors

        _test_dataset(data)
        data = DecathlonDataset(root_dir=testing_dir,
                                task="Task04_Hippocampus",
                                transform=transform,
                                section="validation",
                                download=False)
        _test_dataset(data)
        data = DecathlonDataset(root_dir=testing_dir,
                                task="Task04_Hippocampus",
                                section="validation",
                                download=False)
        self.assertTupleEqual(data[0]["image"].shape, (33, 47, 34))
        shutil.rmtree(os.path.join(testing_dir, "Task04_Hippocampus"))
        try:
            data = DecathlonDataset(
                root_dir=testing_dir,
                task="Task04_Hippocampus",
                transform=transform,
                section="validation",
                download=False,
            )
        except RuntimeError as e:
            print(str(e))
            self.assertTrue(str(e).startswith("Cannot find dataset directory"))
Example #14
 def test_range_scale(self):
     key = "img"
     for p in TEST_NDARRAYS:
         scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)
         result = scaler({key: p(self.imt)})
         mina = np.min(self.imt)
         maxa = np.max(self.imt)
         norm = (self.imt - mina) / (maxa - mina)
         expected = (norm * (2.0 - 1.0)) + 1.0
         assert_allclose(result[key], p(expected), type_test="tensor")
Example #15
def get_image_transforms():
    itk_reader = monai.data.ITKReader()
    # Define transforms for image
    image_transforms = Compose([
        LoadImaged(keys=['img'], reader=itk_reader),
        EnsureChannelFirstd(keys=['img']),
        ScaleIntensityd(keys=['img']),
        ToTensord(keys=['img']),
    ])
    return image_transforms
Example #16
    def test_values(self):
        tempdir = tempfile.mkdtemp()
        transform = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 52)
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue("image_meta_dict" in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 33, 47, 34))

        try:
            data = DecathlonDataset(root_dir=tempdir,
                                    task="Task04_Hippocampus",
                                    transform=transform,
                                    section="validation",
                                    download=True)
        except RuntimeError as e:
            if str(e).startswith(
                    "download failed due to network issue or permission denied."
            ):
                shutil.rmtree(tempdir)
                return

        _test_dataset(data)
        data = DecathlonDataset(root_dir=tempdir,
                                task="Task04_Hippocampus",
                                transform=transform,
                                section="validation",
                                download=False)
        _test_dataset(data)
        data = DecathlonDataset(root_dir=tempdir,
                                task="Task04_Hippocampus",
                                section="validation",
                                download=False)
        self.assertTupleEqual(data[0]["image"].shape, (33, 47, 34))
        shutil.rmtree(os.path.join(tempdir, "Task04_Hippocampus"))
        try:
            data = DecathlonDataset(root_dir=tempdir,
                                    task="Task04_Hippocampus",
                                    transform=transform,
                                    section="validation",
                                    download=False)
        except RuntimeError as e:
            print(str(e))
            self.assertTrue(
                str(e).startswith("can not find dataset directory"))

        shutil.rmtree(tempdir)
Example #17
    def test_values(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        train_transform = Compose([
            LoadImaged(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])
        val_transform = LoadImaged(keys=["image", "label"])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 52)
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue("image_meta_dict" in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 34, 49, 41))

        cvdataset = CrossValidation(
            dataset_cls=DecathlonDataset,
            nfolds=5,
            seed=12345,
            root_dir=testing_dir,
            task="Task04_Hippocampus",
            section="validation",
            transform=train_transform,
            download=True,
        )

        try:  # will start downloading if testing_dir doesn't have the Decathlon files
            data = cvdataset.get_dataset(folds=0)
        except (ContentTooShortError, HTTPError, RuntimeError) as e:
            print(str(e))
            if isinstance(e, RuntimeError):
                # FIXME: skip MD5 check as current downloading method may fail
                self.assertTrue(str(e).startswith("md5 check"))
            return  # skipping this test due to network connection errors

        _test_dataset(data)

        # test training data for fold [1, 2, 3, 4] of 5 splits
        data = cvdataset.get_dataset(folds=[1, 2, 3, 4])
        self.assertTupleEqual(data[0]["image"].shape, (1, 35, 52, 33))
        self.assertEqual(len(data), 208)
        # test train / validation for fold 4 of 5 splits
        data = cvdataset.get_dataset(folds=[4],
                                     transform=val_transform,
                                     download=False)
        # val_transform doesn't add the channel dim to shape
        self.assertTupleEqual(data[0]["image"].shape, (38, 53, 30))
        self.assertEqual(len(data), 52)
        data = cvdataset.get_dataset(folds=[0, 1, 2, 3])
        self.assertTupleEqual(data[0]["image"].shape, (1, 34, 49, 41))
        self.assertEqual(len(data), 208)
Example #18
def get_test_transforms(end_image_shape):
    test_transforms = Compose([
        LoadNiftid(keys=["image"]),
        AddChanneld(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Winsorized(keys=["image"]),
        NormalizeIntensityd(keys=["image"]),
        ScaleIntensityd(keys=["image"]),
        SpatialPadd(keys=["image"], spatial_size=end_image_shape),
        ToTensord(keys=["image"]),
    ])
    return test_transforms
Example #19
 def test_channel_wise(self):
     key = "img"
     for p in TEST_NDARRAYS:
         scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0, channel_wise=True)
         data = p(self.imt)
         result = scaler({key: data})
         mina = self.imt.min()
         maxa = self.imt.max()
         for i, c in enumerate(data):
             norm = (c - mina) / (maxa - mina)
             expected = p((norm * (2.0 - 1.0)) + 1.0)
             assert_allclose(result[key][i], expected, type_test="tensor", rtol=1e-7, atol=0)
Example #20
    def setUp(self):

        self.root_dir = os.environ.get("MONAI_DATA_DIRECTORY")
        if not self.root_dir:
            self.root_dir = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), "testing_data")

        self.transforms = Compose([
            LoadImaged(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys="image"),
        ])
Example #21
    def test_values(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        transform = Compose([
            LoadImaged(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 5986)
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue("image_meta_dict" in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))

        try:  # will start downloading if testing_dir doesn't have the MedNIST files
            data = MedNISTDataset(root_dir=testing_dir,
                                  transform=transform,
                                  section="test",
                                  download=True)
        except (ContentTooShortError, HTTPError, RuntimeError) as e:
            print(str(e))
            if isinstance(e, RuntimeError):
                # FIXME: skip MD5 check as current downloading method may fail
                self.assertTrue(str(e).startswith("md5 check"))
            return  # skipping this test due to network connection errors

        _test_dataset(data)

        # testing loading from the already-downloaded data
        data = MedNISTDataset(root_dir=testing_dir,
                              transform=transform,
                              section="test",
                              download=False)
        data.get_num_classes()
        _test_dataset(data)
        data = MedNISTDataset(root_dir=testing_dir,
                              section="test",
                              download=False)
        self.assertTupleEqual(data[0]["image"].shape, (64, 64))
        shutil.rmtree(os.path.join(testing_dir, "MedNIST"))
        try:
            data = MedNISTDataset(root_dir=testing_dir,
                                  transform=transform,
                                  section="test",
                                  download=False)
        except RuntimeError as e:
            print(str(e))
            self.assertTrue(str(e).startswith("Cannot find dataset directory"))
Example #22
    def test_values(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        train_transform = Compose([
            LoadImaged(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])
        val_transform = LoadImaged(keys=["image", "label"])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 52)
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue(PostFix.meta("image") in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 34, 49, 41))

        cvdataset = CrossValidation(
            dataset_cls=DecathlonDataset,
            nfolds=5,
            seed=12345,
            root_dir=testing_dir,
            task="Task04_Hippocampus",
            section="validation",
            transform=train_transform,
            download=True,
        )

        with skip_if_downloading_fails():
            data = cvdataset.get_dataset(folds=0)

        _test_dataset(data)

        # test training data for fold [1, 2, 3, 4] of 5 splits
        data = cvdataset.get_dataset(folds=[1, 2, 3, 4])
        self.assertTupleEqual(data[0]["image"].shape, (1, 35, 52, 33))
        self.assertEqual(len(data), 208)
        # test train / validation for fold 4 of 5 splits
        data = cvdataset.get_dataset(folds=[4],
                                     transform=val_transform,
                                     download=False)
        # val_transform doesn't add the channel dim to shape
        self.assertTupleEqual(data[0]["image"].shape, (38, 53, 30))
        self.assertEqual(len(data), 52)
        data = cvdataset.get_dataset(folds=[0, 1, 2, 3])
        self.assertTupleEqual(data[0]["image"].shape, (1, 34, 49, 41))
        self.assertEqual(len(data), 208)
Example #23
    def test_values(self):
        tempdir = tempfile.mkdtemp()
        transform = Compose([
            LoadPNGd(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 5986)
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue("image_meta_dict" in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))

        data = MedNISTDataset(root_dir=tempdir,
                              transform=transform,
                              section="test",
                              download=True)
        _test_dataset(data)
        data = MedNISTDataset(root_dir=tempdir,
                              transform=transform,
                              section="test",
                              download=False)
        _test_dataset(data)
        data = MedNISTDataset(root_dir=tempdir, section="test", download=False)
        self.assertTupleEqual(data[0]["image"].shape, (64, 64))
        shutil.rmtree(os.path.join(tempdir, "MedNIST"))
        try:
            data = MedNISTDataset(root_dir=tempdir,
                                  transform=transform,
                                  section="test",
                                  download=False)
        except RuntimeError as e:
            print(str(e))
            self.assertTrue(
                str(e).startswith("can not find dataset directory"))

        shutil.rmtree(tempdir)
Example #24
    def _get_loader(self, folders):
        images = []
        segs = []
        for folder in folders:
            images += glob(os.path.join(folder, "*_im.nii.gz"))
            segs += glob(os.path.join(folder, "*_seg.nii.gz"))
        images = sorted(images, key=os.path.basename)
        segs = sorted(segs, key=os.path.basename)

        files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        transforms = Compose([
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(keys=["img", "seg"],
                                   label_key="seg",
                                   spatial_size=[96, 96, 96],
                                   pos=1,
                                   neg=1,
                                   num_samples=4),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ])

        ds = CacheDataset(data=files, transform=transforms)
        loader = DataLoader(
            ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )

        return loader
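RandCropByPosNegLabeld turns each input sample into a list of num_samples patch dictionaries, which is why the loader above needs list_data_collate to flatten them into a batch (batch_size=2 therefore yields 2 x 4 patches per iteration). A self-contained sketch on synthetic arrays:

import numpy as np
from monai.transforms import RandCropByPosNegLabeld

data = {
    "img": np.random.rand(1, 32, 32, 32).astype(np.float32),
    "seg": (np.random.rand(1, 32, 32, 32) > 0.95).astype(np.float32),
}
cropper = RandCropByPosNegLabeld(keys=["img", "seg"], label_key="seg",
                                 spatial_size=[16, 16, 16], pos=1, neg=1, num_samples=4)
patches = cropper(data)
print(len(patches), patches[0]["img"].shape)  # 4 patch dicts, each (1, 16, 16, 16)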
Example #25
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 5 random image, mask pairs in the temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # model file path
    model_file = glob("./runs/net_key_metric*")[0]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
        # if the GPU has no FP16 support or PyTorch < 1.6, AMP evaluation will not be enabled
        amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
    )
    evaluator.run()
Example #26
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = (
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100))
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True,
                        dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test EnsureTensor for complicated dict data and invert it
            CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
            # test to support Tensor, Numpy array and dictionary when inverting
            EnsureTyped(keys=["image", "test_dict"]),
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label",
                       times=2,
                       names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image",
                       times=2,
                       names=["image_inverted", "image_inverted1"]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
        inverter = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted", "label_inverted", "test_dict"],
            transform=transform,
            orig_keys=["label", "label", "test_dict"],
            meta_keys=[
                PostFix.meta("image_inverted"),
                PostFix.meta("label_inverted"), None
            ],
            orig_meta_keys=[
                PostFix.meta("label"),
                PostFix.meta("label"), None
            ],
            nearest_interp=True,
            to_tensor=[True, False, False],
            device="cpu",
        )

        inverter_1 = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted1", "label_inverted1"],
            transform=transform,
            orig_keys=["image", "image"],
            meta_keys=[
                PostFix.meta("image_inverted1"),
                PostFix.meta("label_inverted1")
            ],
            orig_meta_keys=[PostFix.meta("image"),
                            PostFix.meta("image")],
            nearest_interp=[True, False],
            to_tensor=[True, True],
            device="cpu",
        )

        expected_keys = [
            "image",
            "image_inverted",
            "image_inverted1",
            PostFix.meta("image_inverted1"),
            PostFix.meta("image_inverted"),
            PostFix.meta("image"),
            "image_transforms",
            "label",
            "label_inverted",
            "label_inverted1",
            PostFix.meta("label_inverted1"),
            PostFix.meta("label_inverted"),
            PostFix.meta("label"),
            "label_transforms",
            "test_dict",
            "test_dict_transforms",
        ]
        # execute 1 epoch
        for d in loader:
            d = decollate_batch(d)
            for item in d:
                item = inverter(item)
                item = inverter_1(item)

                self.assertListEqual(sorted(item), expected_keys)
                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                i = item["image_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                i = item["label_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                # test inverted test_dict
                self.assertTrue(
                    isinstance(item["test_dict"]["affine"], np.ndarray))
                self.assertTrue(
                    isinstance(item["test_dict"]["filename_or_obj"], str))

                # check the case where different items use different interpolation modes to invert the transforms
                d = item["image_inverted1"]
                # if the interpolation mode is nearest, accumulated diff should be smaller than 1
                self.assertLess(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 1.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

                d = item["label_inverted1"]
                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
                self.assertGreater(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 10000.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

        # check labels match
        reverted = item["label_inverted"].detach().cpu().numpy().astype(
            np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 34007: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1821: windows torch 1.10.0
        self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821),
                        f"diff.  {reverted.size - n_good}")

        set_determinism(seed=None)
Example #27
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)]

# define transforms for image and segmentation
val_transforms = Compose([
    LoadNiftid(keys=['img', 'seg']),
    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
    ScaleIntensityd(keys=['img', 'seg']),
    ToTensord(keys=['img', 'seg'])
])
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

device = torch.device("cuda:0")
net = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
net.to(device)
Example #28
def run_training_test(root_dir, device="cuda:0"):
    real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    train_files = [{"reals": img} for img in zip(real_images)]

    # prepare real data
    train_transforms = Compose([
        LoadNiftid(keys=["reals"]),
        AsChannelFirstd(keys=["reals"]),
        ScaleIntensityd(keys=["reals"]),
        RandFlipd(keys=["reals"], prob=0.5),
        ToTensord(keys=["reals"]),
    ])
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=2,
                                         shuffle=True,
                                         num_workers=4)

    learning_rate = 2e-4
    betas = (0.5, 0.999)
    real_label = 1
    fake_label = 0

    # create discriminator
    disc_net = Discriminator(in_shape=(1, 64, 64),
                             channels=(8, 16, 32, 64, 1),
                             strides=(2, 2, 2, 2, 1),
                             num_res_units=1,
                             kernel_size=5).to(device)
    disc_net.apply(normal_init)
    disc_opt = torch.optim.Adam(disc_net.parameters(),
                                learning_rate,
                                betas=betas)
    disc_loss_criterion = torch.nn.BCELoss()

    def discriminator_loss(gen_images, real_images):
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
        realloss = disc_loss_criterion(disc_net(real_images), real)
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
        return torch.div(torch.add(realloss, genloss), 2)

    # create generator
    latent_size = 64
    gen_net = Generator(latent_shape=latent_size,
                        start_shape=(latent_size, 8, 8),
                        channels=[32, 16, 8, 1],
                        strides=[2, 2, 2, 1])
    gen_net.apply(normal_init)
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)
    gen_opt = torch.optim.Adam(gen_net.parameters(),
                               learning_rate,
                               betas=betas)
    gen_loss_criterion = torch.nn.BCELoss()

    def generator_loss(gen_images):
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    key_train_metric = None

    train_handlers = [
        StatsHandler(
            name="training_loss",
            output_transform=lambda x: {
                Keys.GLOSS: x[Keys.GLOSS],
                Keys.DLOSS: x[Keys.DLOSS]
            },
        ),
        TensorBoardStatsHandler(
            log_dir=root_dir,
            tag_name="training_loss",
            output_transform=lambda x: {
                Keys.GLOSS: x[Keys.GLOSS],
                Keys.DLOSS: x[Keys.DLOSS]
            },
        ),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={
                            "g_net": gen_net,
                            "d_net": disc_net
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    disc_train_steps = 2
    num_epochs = 5

    trainer = GanTrainer(
        device,
        num_epochs,
        train_loader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=train_handlers,
    )
    trainer.run()

    return trainer.state
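discriminator_loss averages the BCE losses on real and generated scores; the same computation on toy discriminator outputs (a sketch, independent of the networks above):

import torch

bce = torch.nn.BCELoss()
real_scores = torch.tensor([[0.9], [0.8]])  # discriminator outputs for real images
fake_scores = torch.tensor([[0.2], [0.1]])  # discriminator outputs for generated images
real_loss = bce(real_scores, torch.ones_like(real_scores))
fake_loss = bce(fake_scores, torch.zeros_like(fake_scores))
print((real_loss + fake_loss) / 2)          # matches the averaging in discriminator_loss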
Example #29
def run_training_test(root_dir, device="cuda:0", cachedataset=0):
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = val_post_tran(sliding_window_inference(val_images, roi_size, sw_batch_size, model))
                    value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += not_nans.item()
                    metric_sum += value.item() * not_nans.item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
Example no. 30
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # binary labels for gender classification (two classes: man and woman)
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    train_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[:10], labels[:10])]
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image
    train_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    )
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda:0")

    # the Ignite trainer expects batch=(img, label) and returns output=loss at every iteration;
    # users can add an output_transform to return other values, such as y_pred and y
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch["img"], batch["label"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(net,
                                        opt,
                                        loss,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)
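    # sketch (assumption, not in the original): to make the trainer emit (y_pred, y)
    # instead of the loss, Ignite's create_supervised_trainer also accepts an
    # output_transform, e.g.
    # trainer = create_supervised_trainer(
    #     net, opt, loss, device, False, prepare_batch=prepare_batch,
    #     output_transform=lambda x, y, y_pred, loss: (y_pred, y))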

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "net": net,
                                  "opt": opt
                              })
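    # sketch (assumption, not in the original; the file name and API depend on the
    # Ignite version): a saved checkpoint can later be restored, e.g.
    # from ignite.handlers import Checkpoint
    # Checkpoint.load_objects(to_load={"net": net, "opt": opt},
    #                         checkpoint=torch.load("./runs/net_checkpoint_10.pt"))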

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # no metrics are set for the trainer here, so only the loss is printed. Users can
    # customize the print functions and use output_transform to convert engine.state.output
    # if it is not a loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and metrics at every epoch, analogous to StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {
        metric_name: Accuracy(),
        "AUC": ROCAUC(to_onehot_y=True, add_softmax=True)
    }
    # the Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every
    # iteration; users can add an output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        # no need to print the loss value, so disable per-iteration output
        output_transform=lambda x: None,
        # fetch the global epoch number from the trainer
        global_epoch_transform=lambda x: trainer.state.epoch,
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        # no need to plot the loss value, so disable per-iteration output
        output_transform=lambda x: None,
        # fetch the global epoch number from the trainer
        global_epoch_transform=lambda x: trainer.state.epoch,
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)
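    # stopping_fn_from_metric builds a score function that reads
    # engine.state.metrics["Accuracy"] from the evaluator; when the score fails to
    # improve for `patience` validation runs, the trainer is terminated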

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
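    # illustrative (not in the original): trainer.run returns the final engine state,
    # which can be inspected after training, e.g.
    # print(f"training finished after {state.epoch} epochs, {state.iteration} iterations")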