示例#1
0
 def _sliding_window_processor(engine, batch):
     """Run sliding-window inference on one (image, label) batch.

     ``engine`` is accepted but never used. The first two items of ``batch``
     are moved to ``device``; the function returns the output of
     ``sliding_window_inference`` together with the moved labels.
     """
     net.eval()
     with torch.no_grad():
         images = batch[0].to(device)
         labels = batch[1].to(device)
         probabilities = sliding_window_inference(images, roi_size,
                                                  sw_batch_size, net)
     return probabilities, labels
示例#2
0
    def validation(epoch_iterator_val):
        """Evaluate ``model`` on the validation iterator and return the mean Dice.

        Each batch is run through ``sliding_window_inference`` with a
        (96, 96, 96) window and 4 windows per forward pass. Labels and
        predictions are decollated and post-processed with ``post_label`` /
        ``post_pred`` before being fed to ``dice_metric``.

        The metric object is only reset after the loop, so every
        ``aggregate()`` call already covers all batches seen so far. The
        value returned is therefore the aggregate after the last batch,
        rather than an average of those running aggregates (which would
        weight early batches more heavily).

        Args:
            epoch_iterator_val: iterable of dict batches with "image" and
                "label" entries; it must also provide ``set_description``.

        Returns:
            The Dice score aggregated over every batch as a float, or NaN
            when the iterator yields nothing.
        """
        model.eval()
        # NaN matches what averaging an empty list of scores used to yield.
        mean_dice_val = float("nan")

        with torch.no_grad():
            for batch in epoch_iterator_val:
                val_inputs = batch["image"].cuda()
                val_labels = batch["label"].cuda()
                val_outputs = sliding_window_inference(val_inputs,
                                                       (96, 96, 96), 4, model)
                val_labels_convert = [
                    post_label(val_label_tensor)
                    for val_label_tensor in decollate_batch(val_labels)
                ]
                val_output_convert = [
                    post_pred(val_pred_tensor)
                    for val_pred_tensor in decollate_batch(val_outputs)
                ]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # running aggregate over all batches accumulated so far
                mean_dice_val = dice_metric.aggregate().item()
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" %
                    (global_step, 10.0, mean_dice_val))

            dice_metric.reset()

        return mean_dice_val
 def _sliding_window_processor(_engine, batch):
     """Segment the image of a three-item batch with sliding-window inference.

     ``_engine`` is unused. ``batch`` must unpack into exactly three items;
     only the first (the image) is consumed. Returns the result of
     ``predict_segmentation`` applied to the inference output.
     """
     net.eval()
     # Unpacking enforces the three-item layout even though two are unused.
     img, _seg, _meta_data = batch
     with torch.no_grad():
         probabilities = sliding_window_inference(img.to(device), roi_size,
                                                  sw_batch_size, net)
         segmentation = predict_segmentation(probabilities)
     return segmentation
示例#4
0
    def test_sliding_window_default(self, image_shape, roi_shape,
                                    sw_batch_size, overlap, mode, device):
        """Check the functional and class-based sliding-window APIs.

        The predictor adds 1 to its input, so the stitched result must equal
        the input plus one and stay on the requested device. For
        ``mode == "constant"`` the input is a ramp of distinct values,
        otherwise it is all ones. A CUDA device is swapped for CPU when CUDA
        is not available.
        """
        n_total = np.prod(image_shape)
        if mode == "constant":
            inputs = torch.arange(n_total,
                                  dtype=torch.float).reshape(*image_shape)
            expected_val = np.arange(
                n_total, dtype=np.float32).reshape(*image_shape) + 1.0
        else:
            inputs = torch.ones(*image_shape, dtype=torch.float)
            expected_val = np.ones(image_shape, dtype=np.float32) + 1.0

        if device.type == "cuda" and not torch.cuda.is_available():
            device = torch.device("cpu:0")

        def compute(data):
            return data + 1

        def check(result):
            # same device type as requested and same values as expected
            np.testing.assert_string_equal(device.type, result.device.type)
            np.testing.assert_allclose(result.cpu().numpy(), expected_val)

        check(
            sliding_window_inference(inputs.to(device),
                                     roi_shape,
                                     sw_batch_size,
                                     compute,
                                     overlap,
                                     mode=mode))

        inferer = SlidingWindowInferer(roi_shape, sw_batch_size, overlap,
                                       mode)
        check(inferer(inputs.to(device), compute))
示例#5
0
    def test_args_kwargs(self):
        """Extra positional and keyword arguments must reach the predictor.

        The predictor adds two one-valued tensors to an all-ones input, so
        both the functional call and ``SlidingWindowInferer`` are expected to
        return an array filled with 3.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu:0"
        image_shape = (1, 1, 3, 3)
        inputs = torch.ones(image_shape).to(device=device)
        t1 = torch.ones(1).to(device=device)
        t2 = torch.ones(1).to(device=device)
        roi_shape = (5, 5)
        sw_batch_size = 10
        expected = np.ones(image_shape) + 2.0

        def compute(data, test1, test2):
            return data + test1 + test2

        # Everything is passed positionally so that ``t1`` lands in the
        # predictor's extra positional arguments; ``t2`` goes by keyword.
        positional = (
            inputs,
            roi_shape,
            sw_batch_size,
            compute,
            0.5,
            "constant",
            1.0,
            "constant",
            0.0,
            device,
            device,
            t1,
        )
        functional_result = sliding_window_inference(*positional, test2=t2)
        np.testing.assert_allclose(functional_result.cpu().numpy(),
                                   expected,
                                   rtol=1e-4)

        inferer = SlidingWindowInferer(roi_shape,
                                       sw_batch_size,
                                       overlap=0.5,
                                       mode="constant",
                                       cval=-1)
        inferer_result = inferer(inputs, compute, t1, test2=t2)
        np.testing.assert_allclose(inferer_result.cpu().numpy(),
                                   expected,
                                   rtol=1e-4)
示例#6
0
    def test_cval(self):
        """Constant padding with ``cval=-1`` must be visible to the predictor.

        The predictor adds the sum of its window to every element, so the
        expected output depends on the padded values; both the functional
        call and ``SlidingWindowInferer`` are compared against -6.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu:0"
        image_shape = (1, 1, 3, 3)
        inputs = torch.ones(image_shape).to(device=device)
        roi_shape = (5, 5)
        sw_batch_size = 10
        expected = np.ones(image_shape) * -6.0

        def compute(data):
            return data + data.sum()

        functional_kwargs = dict(
            overlap=0.5,
            padding_mode="constant",
            cval=-1,
            mode="constant",
            sigma_scale=1.0,
        )
        functional_result = sliding_window_inference(inputs, roi_shape,
                                                     sw_batch_size, compute,
                                                     **functional_kwargs)
        np.testing.assert_allclose(functional_result.cpu().numpy(),
                                   expected,
                                   rtol=1e-4)

        inferer = SlidingWindowInferer(roi_shape,
                                       sw_batch_size,
                                       overlap=0.5,
                                       mode="constant",
                                       cval=-1)
        inferer_result = inferer(inputs, compute)
        np.testing.assert_allclose(inferer_result.cpu().numpy(),
                                   expected,
                                   rtol=1e-4)
示例#7
0
def evaluate(model_path, val_loader, plot_path=None):
    """Load weights into the module-level ``model`` and plot its predictions.

    For every item of ``val_loader`` the function runs sliding-window
    inference and draws three panels for slice index 80 of the last axis:
    the input image, the label and the arg-max of the model output.

    Args:
        model_path: path of the checkpoint passed to ``torch.load``.
        val_loader: iterable of dict batches with "image" and "label".
        plot_path: directory for the PNG files; when ``None`` the figure is
            shown with ``plt.show()`` instead of being saved.
    """
    # map_location lets a checkpoint saved on another device (for example a
    # GPU) be loaded when ``device`` is different.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    roi_size = (160, 160, 160)
    sw_batch_size = 4
    with torch.no_grad():
        for i, val_data in enumerate(val_loader):
            val_outputs = sliding_window_inference(
                val_data['image'].to(device), roi_size, sw_batch_size, model)
            # plot the slice [:, :, 80]
            plt.figure('check', (18, 6))
            plt.subplot(1, 3, 1)
            plt.title(f"image {i}")
            plt.imshow(val_data['image'][0, 0, :, :, 80], cmap='gray')
            plt.subplot(1, 3, 2)
            plt.title(f"label {i}")
            plt.imshow(val_data['label'][0, 0, :, :, 80])
            plt.subplot(1, 3, 3)
            plt.title(f"output {i}")
            plt.imshow(
                torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, 80])
            if plot_path is not None:
                fig_path = os.path.join(plot_path,
                                        "val_{}_evaluation.png".format(i))
                plt.savefig(fig_path)
                print("Saved {}".format(fig_path))
                # The figure is looked up by name, so without closing it the
                # next iteration would draw on top of the same axes.
                plt.close('check')
            else:
                plt.show()
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Run sliding-window inference on the NIfTI pairs found in ``root_dir``.

    Image/segmentation files matching ``im*.nii.gz`` / ``seg*.nii.gz`` are
    loaded, a UNet is restored from ``best_metric_model.pth`` in the same
    directory, and thresholded predictions are written below
    ``<root_dir>/output``.

    Args:
        root_dir: directory holding the input volumes and the checkpoint.
        device: device on which the model and the data are placed.

    Returns:
        The Dice score averaged over the batches, weighted by the metric's
        ``not_nans`` count.

    Raises:
        ZeroDivisionError: if the accumulated ``not_nans`` count is zero
            (for example when no files are found).
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        Spacingd(keys=["img", "seg"],
                 pixdim=[1.2, 0.8, 0.7],
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs one image per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # map_location makes the checkpoint loadable on ``device`` even when it
    # was saved from a different device.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=np.float32)
        # define sliding window size and batch size for windows inference
        sw_batch_size, roi_size = 4, (96, 96, 96)
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            not_nans = dice_metric.not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans
            # binarise the sigmoid output at 0.5 before saving
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = metric_sum / metric_count
    return metric
示例#9
0
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Run sliding-window inference on the NIfTI pairs found in ``root_dir``.

    Files matching ``im*.nii.gz`` / ``seg*.nii.gz`` are loaded, a UNet is
    restored from ``best_metric_model.pth`` in the same directory, the Dice
    score is computed with ``compute_meandice`` and thresholded predictions
    are written below ``<root_dir>/output``.

    Args:
        root_dir: directory holding the input volumes and the checkpoint.
        device: device on which the model and the data are placed.

    Returns:
        Sum of all Dice values divided by the number of values.

    Raises:
        ZeroDivisionError: if no Dice value was accumulated (for example
            when no files are found).
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs one image per iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # map_location makes the checkpoint loadable on ``device`` even when it
    # was saved from a different device.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=int)
        # define sliding window size and batch size for windows inference
        sw_batch_size, roi_size = 4, (96, 96, 96)
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = compute_meandice(y_pred=val_outputs,
                                     y=val_labels,
                                     include_background=True,
                                     to_onehot_y=False,
                                     add_sigmoid=True)
            metric_count += len(value)
            metric_sum += value.sum().item()
            # binarise the sigmoid output at 0.5 before saving
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(
                val_outputs, {
                    "filename_or_obj": val_data["img.filename_or_obj"],
                    "affine": val_data["img.affine"]
                })
        metric = metric_sum / metric_count
    return metric
示例#10
0
def main(tempdir):
    """Evaluate a 2D UNet on synthetic PNG data written to ``tempdir``.

    Five synthetic image/segmentation pairs are generated and saved as PNG,
    then loaded through dictionary transforms. The model weights come from
    ``best_metric_model_segmentation2d_dict.pth`` in the working directory;
    thresholded predictions are saved below ``./output`` and the averaged
    Dice score is printed.

    Args:
        tempdir: directory that receives the generated PNG files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        # The synthetic image is floating point; scale it to the 8-bit range
        # before the cast, otherwise almost every pixel truncates to 0.
        # NOTE(review): assumes create_test_image_2d yields values in [0, 1].
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        # Labels are left unscaled because "seg" is not rescaled on load.
        Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location allows a GPU-saved checkpoint to load on a CPU-only host
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth", map_location=device))

    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = PNGSaver(output_dir="./output")
        # define sliding window size and batch size for windows inference
        roi_size = (96, 96)
        sw_batch_size = 4
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            # binarise the sigmoid output at 0.5 before saving
            val_outputs = val_outputs.sigmoid() >= 0.5
            saver.save_batch(val_outputs)
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
def run_inference_test(root_dir, device="cuda:0"):
    """Run sliding-window inference on the NIfTI pairs found in ``root_dir``.

    Files matching ``im*.nii.gz`` / ``seg*.nii.gz`` are loaded and
    resampled, a UNet is restored from ``best_metric_model.pth`` in the same
    directory, each prediction is post-processed (sigmoid, threshold 0.5)
    and saved below ``<root_dir>/output``.

    Args:
        root_dir: directory holding the input volumes and the checkpoint.
        device: device on which the model and the data are placed.

    Returns:
        The aggregated value of the Dice metric as a Python number.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # map_location makes the checkpoint loadable on ``device`` even when it
    # was saved from a different device.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = SaveImage(
            output_dir=os.path.join(root_dir, "output"),
            dtype=np.float32,
            output_ext=".nii.gz",
            output_postfix="seg",
            mode="bilinear",
        )
        # define sliding window size and batch size for windows inference
        sw_batch_size, roi_size = 4, (96, 96, 96)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            val_meta = decollate_batch(val_data[PostFix.meta("img")])
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            for img, meta in zip(val_outputs, val_meta):  # save a decollated batch of files
                saver(img, meta)

    return dice_metric.aggregate().item()
示例#12
0
 def sliding_window_inference(self, image):
     """Apply the module-level ``sliding_window_inference`` to ``image``.

     Window size, batch size, overlap and blend mode are read from
     ``self.patch_size`` and ``self.args``; ``self.model`` is the predictor.
     """
     # Inside the body the name resolves to the module-level function, not
     # to this method.
     inference_kwargs = dict(
         inputs=image,
         roi_size=self.patch_size,
         sw_batch_size=self.args.val_batch_size,
         predictor=self.model,
         overlap=self.args.overlap,
         mode=self.args.blend,
     )
     return sliding_window_inference(**inference_kwargs)
 def _sliding_window_processor(_engine, batch):
     """Segment the first item of ``batch`` with sliding-window inference.

     ``_engine`` is unused. Returns the result of ``predict_segmentation``
     applied to the inference output.
     """
     # first item from ImageDataset is the input image
     image = batch[0]
     with eval_mode(net):
         probabilities = sliding_window_inference(
             img_on_device := image.to(device),
             roi_size,
             sw_batch_size,
             net,
             device=device,
         ) if False else sliding_window_inference(image.to(device),
                                                  roi_size,
                                                  sw_batch_size,
                                                  net,
                                                  device=device)
         segmentation = predict_segmentation(probabilities)
     return segmentation
 def _sliding_window_processor(_engine, batch):
     """Segment the image of a three-item batch with sliding-window inference.

     ``_engine`` is unused. ``batch`` must unpack into exactly three items;
     only the first (the image) is consumed. Returns the result of
     ``predict_segmentation`` applied to the inference output.
     """
     # Unpacking enforces the three-item layout even though two are unused.
     image, _seg, _meta_data = batch
     with eval_mode(net):
         probabilities = sliding_window_inference(image.to(device),
                                                  roi_size,
                                                  sw_batch_size,
                                                  net,
                                                  device=device)
         segmentation = predict_segmentation(probabilities)
     return segmentation
示例#15
0
 def _sliding_window_processor(engine, batch):
     """Infer, post-process and save the predictions for one dict batch.

     ``engine`` is unused. The "img" and "seg" entries are moved to
     ``device``; every decollated prediction goes through ``post_trans`` and
     is passed to ``save_image`` with its entry from "img_meta_dict".
     Returns the list of post-processed predictions and the labels.
     """
     net.eval()
     with torch.no_grad():
         val_images = batch["img"].to(device)
         val_labels = batch["seg"].to(device)
         raw_output = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
         predictions = [post_trans(item) for item in decollate_batch(raw_output)]
         meta_items = decollate_batch(batch["img_meta_dict"])
         for prediction, meta in zip(predictions, meta_items):
             save_image(prediction, meta)
     return predictions, val_labels
示例#16
0
    def test_step(self, batch, batch_idx):
        """Score one test batch with sliding-window inference.

        Reads "IMAGE" and "SEGM" from ``batch`` and the window settings from
        ``self.userinputs``; ``batch_idx`` is unused. Returns a dict with
        the value of ``self.val_metric`` under "test_dice".
        """
        images = batch["IMAGE"]
        labels = batch["SEGM"]
        outputs = sliding_window_inference(images,
                                           self.userinputs['roi_size'],
                                           self.userinputs['sw_batch_size'],
                                           self.forward)

        # TODO: handle test images that come without masks
        return {"test_dice": self.val_metric(y_pred=outputs, y=labels)}
示例#17
0
    def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode):
        """An add-one predictor on an all-ones input must yield all twos (CPU only)."""
        cpu = torch.device("cpu:0")
        inputs = torch.ones(*image_shape)

        def compute(data):
            return data + 1

        result = sliding_window_inference(inputs.to(cpu), roi_shape, sw_batch_size, compute, overlap, mode=mode)
        expected_val = np.ones(image_shape, dtype=np.float32) + 1
        matches = np.allclose(result.numpy(), expected_val)
        self.assertTrue(matches)
def main(tempdir):
    """Evaluate a 3D UNet on synthetic NIfTI data written to ``tempdir``.

    Five synthetic image/segmentation volumes are generated and saved, then
    loaded through ``ImageDataset``. The model weights come from
    ``best_metric_model_segmentation3d_array.pth`` in the working directory;
    post-processed predictions are saved below ``./output`` and the averaged
    Dice score is printed.

    Args:
        tempdir: directory that receives the generated NIfTI files.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location allows a GPU-saved checkpoint to load on a CPU-only host
    model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth", map_location=device))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        for val_data in val_loader:
            # items 0 and 1 are image and label; item 2 is handed to the saver
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
示例#19
0
    def test_default_device(self):
        """Without explicit device arguments the result stays on the input's device."""
        device = "cuda" if torch.cuda.is_available() else "cpu:0"
        image_shape = (1, 3, 16, 15, 7)
        inputs = torch.ones(image_shape).to(device=device)
        roi_shape = (4, 10, 7)
        sw_batch_size = 10

        def compute(data):
            return data + 1

        result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute)
        np.testing.assert_string_equal(inputs.device.type, result.device.type)
        # add-one predictor on all ones -> all twos
        expected_val = np.ones(image_shape, dtype=np.float32) + 1
        np.testing.assert_allclose(result.cpu().numpy(), expected_val)
示例#20
0
    def test_sw_device(self):
        """With ``sw_device="cuda"`` windows run on CUDA while the result follows the input.

        The predictor itself asserts that it receives CUDA tensors; the
        stitched output is compared against the CPU input's device type.
        """
        image_shape = (1, 3, 16, 15, 7)
        inputs = torch.ones(image_shape).to(device="cpu")
        roi_shape = (4, 10, 7)
        sw_batch_size = 10

        def compute(data):
            self.assertEqual(data.device.type, "cuda")
            one = torch.tensor(1, device="cuda")
            return data + one

        result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, sw_device="cuda")
        np.testing.assert_string_equal(inputs.device.type, result.device.type)
        expected_val = np.ones(image_shape, dtype=np.float32) + 1
        np.testing.assert_allclose(result.cpu().numpy(), expected_val)
示例#21
0
 def predict(self, X, model):  # noqa: N803
     """Run sliding-window inference over every ``(inputs, metadata)`` pair in ``X``.

     ``model`` must unpack into two items; the first is used as the network
     and the second is discarded. Returns a list of
     ``(post-processed output, metadata)`` tuples in iteration order.
     """
     network, _ = model  # drop the optimizer
     network.eval()
     predictions = []
     with torch.no_grad():
         for inputs, metadata in X:
             on_device = inputs.to(self.device)
             roi_size = (96, 96, 96)
             sw_batch_size = 4
             raw_output = sliding_window_inference(on_device, roi_size, sw_batch_size, network)
             predictions.append((post_trans(raw_output), metadata))
     return predictions
 def forward(self, batch):
     """Produce the model output for ``batch`` according to the active loader.

     * train loader: a plain forward pass of ``self.model``.
     * valid loader: sliding-window inference on the batch as given.
     * infer loader: the batch is first moved to ``self.device``, then
       sliding-window inference is run and the batch entries are merged
       into the returned dict (batch entries win on key clashes).

     The flags are checked in that order, so the first one that is set
     decides the behaviour.

     Returns:
         A dict holding the prediction under ``self.output_key``.

     Raises:
         RuntimeError: if none of the three loader flags is set.
     """
     if self.is_train_loader:
         return {self.output_key: self.model(batch[self.input_key])}

     if self.is_valid_loader:
         merge_batch = False
     elif self.is_infer_loader:
         batch = self._batch2device(batch, self.device)
         merge_batch = True
     else:
         # Previously this fell through to an unbound local variable.
         raise RuntimeError(
             "forward() called outside of a train, valid or infer loader")

     roi_size = (96, 96, 96)
     sw_batch_size = 4
     output = {
         self.output_key:
         sliding_window_inference(batch[self.input_key], roi_size,
                                  sw_batch_size, self.model)
     }
     if merge_batch:
         output = {**output, **batch}
     return output
示例#23
0
 def validation_step(self, batch, batch_idx):
     """Compute the validation loss and mean Dice for one batch.

     Sliding-window inference uses ``PATCH_SIZE`` windows, one window per
     forward pass, with ``self.forward`` as predictor. The loss is computed
     on the raw outputs; the Dice (background excluded) is computed after
     ``self.post_pred`` / ``self.post_label``. ``batch_idx`` is unused.
     """
     images = batch["image"]
     labels = batch["label"]
     raw_outputs = sliding_window_inference(images, PATCH_SIZE, 1,
                                            self.forward)
     loss = self.loss_function(raw_outputs, labels)
     predictions = self.post_pred(raw_outputs)
     targets = self.post_label(labels)
     dice = compute_meandice(y_pred=predictions,
                             y=targets,
                             include_background=False)
     return {"val_loss": loss, "val_dice": dice}
示例#24
0
def main(tempdir):
    """Evaluate a 2D UNet on synthetic PNG data written to ``tempdir``.

    Five synthetic image/segmentation pairs are generated, scaled by 255 and
    saved as 8-bit PNG, then loaded through array transforms. The model
    weights come from ``best_metric_model_segmentation2d_array.pth`` in the
    working directory; post-processed predictions are saved below
    ``./output`` and the aggregated Dice score is printed.

    Args:
        tempdir: directory that receives the generated PNG files.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_ds = ArrayDataset(images, imtrans, segs, segtrans)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location allows a GPU-saved checkpoint to load on a CPU-only host
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth", map_location=device))
    model.eval()
    with torch.no_grad():
        # define sliding window size and batch size for windows inference
        roi_size = (96, 96)
        sw_batch_size = 4
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
示例#25
0
def main():
    """Segment every ``case*.nii.gz`` volume in ``IMAGE_FOLDER`` with CopleNet.

    The weights are read from the "model_state_dict" entry of
    ``MODEL_FILE``. Each volume is processed with sliding-window inference
    and the arg-max over the channel axis is saved to ``OUTPUT_FOLDER``.
    """
    images = sorted(glob(os.path.join(IMAGE_FOLDER, "case*.nii.gz")))
    val_files = [{"img": img} for img in images]

    # define transforms for the image
    infer_transforms = Compose([
        LoadNiftid("img"),
        AddChanneld("img"),
        Orientationd(
            "img",
            "SPL"),  # coplenet works on the plane defined by the last two axes
        ToTensord("img"),
    ])
    test_ds = monai.data.Dataset(data=val_files, transform=infer_transforms)
    # sliding window inference need to input 1 image in every iteration
    data_loader = torch.utils.data.DataLoader(
        test_ds,
        batch_size=1,
        num_workers=0,
        pin_memory=torch.cuda.is_available())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CopleNet().to(device)

    # map_location allows a GPU-saved checkpoint to load on a CPU-only host
    checkpoint = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    with torch.no_grad():
        saver = NiftiSaver(output_dir=OUTPUT_FOLDER)
        total = len(data_loader)
        for idx, val_data in enumerate(data_loader):
            print(f"Inference on {idx+1} of {total}")
            val_images = val_data["img"].to(device)
            # define sliding window size and batch size for windows inference:
            # the last two dimensions are rounded up to a multiple of 32 and
            # the window covers 20 positions along the remaining spatial axis
            slice_shape = np.ceil(np.asarray(val_images.shape[3:]) / 32) * 32
            roi_size = (20, int(slice_shape[0]), int(slice_shape[1]))
            sw_batch_size = 2
            val_outputs = sliding_window_inference(val_images,
                                                   roi_size,
                                                   sw_batch_size,
                                                   model,
                                                   0.0,
                                                   padding_mode="circular")
            val_outputs = val_outputs.argmax(dim=1, keepdim=True)
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
    def test_sliding_window_default(self, image_shape, roi_shape,
                                    sw_batch_size, overlap, mode, device):
        """An add-one predictor on an all-ones input must yield all twos.

        A CUDA device is swapped for CPU when CUDA is not available. The
        result must live on the requested device type and match the
        expected values.
        """
        inputs = torch.ones(*image_shape)
        if device.type == "cuda" and not torch.cuda.is_available():
            device = torch.device("cpu:0")

        def compute(data):
            return data + 1

        result = sliding_window_inference(inputs.to(device),
                                          roi_shape,
                                          sw_batch_size,
                                          compute,
                                          overlap,
                                          mode=mode)
        np.testing.assert_string_equal(device.type, result.device.type)
        expected_val = np.ones(image_shape, dtype=np.float32) + 1
        # .cpu() is required: a CUDA tensor cannot be converted to NumPy
        # directly, so the comparison would fail when the test runs on GPU.
        np.testing.assert_allclose(result.cpu().numpy(), expected_val)
示例#27
0
                def plot_dice(images_loader):
                    """Average the Dice score over ``images_loader``.

                    The average is appended to ``metric_values``. Returns the
                    average together with the images, labels and
                    post-processed outputs of the last batch (``None`` for
                    each if the loader was empty, in which case the division
                    raises first).
                    """
                    val_images = val_labels = val_outputs = None
                    metric_sum, metric_count = 0.0, 0
                    for data in images_loader:
                        val_images = data["image"].cuda()
                        val_labels = data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        raw_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                        val_outputs = post_trans(raw_outputs)
                        value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                        n_values = len(value)
                        metric_count += n_values
                        metric_sum += value.item() * len(value)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    return metric, val_images, val_labels, val_outputs
                def plot_dice(images_loader):
                    """Aggregate the Dice metric over ``images_loader``.

                    Returns the aggregated value together with the images,
                    labels and post-processed outputs of the last batch. The
                    metric object is reset before returning.
                    """
                    val_images = val_labels = val_outputs = None
                    for data in images_loader:
                        val_images = data["image"].cuda()
                        val_labels = data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        raw_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                        val_outputs = [post_trans(item) for item in decollate_batch(raw_outputs)]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    # aggregate the final mean dice result, then reset the
                    # status for the next validation round
                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()

                    return metric, val_images, val_labels, val_outputs
def predict_mask(model, images):
    """Predict a label mask for a stack of images with sliding-window inference.

    Args:
        model: predictor passed to ``sliding_window_inference``.
        images: iterable of objects exposing a ``pixel_array`` attribute.

    Returns:
        A C-ordered ``uint8`` NumPy array holding the arg-max over axis 1 of
        the model output, with the leading axis squeezed and the remaining
        axes reversed by ``np.transpose``.
    """
    # Local import: torch is only needed here to disable autograd.
    import torch

    pixel_stack = np.array([image.pixel_array for image in images])
    # reverse the axis order, then prepend two singleton axes
    image_array = np.expand_dims(
        np.transpose(pixel_stack.astype('float32')), (0, 1))
    data_transforms = Compose([AddChannel(), NormalizeIntensity(), ToTensor()])

    dataset = monai.data.Dataset(data=image_array, transform=data_transforms)

    # Fetch the sample once so the transforms are not applied twice.
    sample = dataset[0]
    print(sample.shape)
    # The result is detached below anyway, so there is no reason to record
    # an autograd graph during inference.
    with torch.no_grad():
        test_mask = sliding_window_inference(sample,
                                             roi_size=[128, 128, 16],
                                             sw_batch_size=1,
                                             predictor=model)
    test_mask = test_mask.argmax(1).detach().cpu().numpy()
    test_mask = np.transpose(np.squeeze(test_mask, 0))
    test_mask = test_mask.astype('uint8')
    test_mask = np.asarray(test_mask, order='C')
    return test_mask
示例#30
0
def evaluate(model, val_loader, dice_metric, dice_metric_batch, post_trans):
    """Evaluate ``model`` on ``val_loader`` and return four Dice values.

    Inference runs under CUDA autocast with a (240, 240, 160) window, four
    windows per forward pass and 0.6 overlap. Both metric objects are fed
    the post-processed predictions and reset before returning.

    Returns:
        ``(metric, metric_tc, metric_wt, metric_et)``: the aggregate of
        ``dice_metric`` followed by entries 0, 1 and 2 of the aggregate of
        ``dice_metric_batch``.
    """
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            with torch.cuda.amp.autocast():
                raw_outputs = sliding_window_inference(
                    inputs=val_data["image"],
                    roi_size=(240, 240, 160),
                    sw_batch_size=4,
                    predictor=model,
                    overlap=0.6,
                )
            predictions = [post_trans(item) for item in decollate_batch(raw_outputs)]
            for metric_fn in (dice_metric, dice_metric_batch):
                metric_fn(y_pred=predictions, y=val_data["label"])

        metric = dice_metric.aggregate().item()
        metric_batch = dice_metric_batch.aggregate()
        metric_tc, metric_wt, metric_et = (metric_batch[k].item() for k in range(3))
        dice_metric.reset()
        dice_metric_batch.reset()

    return metric, metric_tc, metric_wt, metric_et