def _sliding_window_processor(engine, batch):
    """Process one batch with sliding-window inference.

    Args:
        engine: not used by this function.
        batch: indexable; ``batch[0]`` is the image tensor and ``batch[1]`` the
            label tensor. Both are moved to the module-level ``device``.

    Returns:
        tuple: ``(seg_probs, labels)``. ``seg_probs`` is the raw output of
        ``sliding_window_inference`` using the module-level ``roi_size``,
        ``sw_batch_size`` and ``net``.
    """
    net.eval()
    with torch.no_grad():
        images = batch[0].to(device)
        labels = batch[1].to(device)
        seg_probs = sliding_window_inference(images, roi_size, sw_batch_size, net)
        return seg_probs, labels
def validation(epoch_iterator_val):
    """Validate the module-level ``model`` and return the dataset-wide mean Dice.

    Each batch is run through sliding-window inference (96x96x96 windows, 4 per
    forward pass). It is post-processed per sample with ``post_pred`` /
    ``post_label`` and accumulated into ``dice_metric``. The progress bar shows
    the running Dice over all batches seen so far.

    ``dice_metric`` accumulates across batches, so its aggregate after the last
    batch already covers the whole loader. That value is returned. Averaging the
    per-step running aggregates instead would over-weight the first batches.

    Args:
        epoch_iterator_val: iterable of dicts with ``"image"`` and ``"label"``
            tensors that also exposes ``set_description`` (e.g. a tqdm wrapper).

    Returns:
        float: the final ``dice_metric.aggregate()`` value, or NaN if the
        iterator yields nothing.
    """
    model.eval()
    dice = float("nan")  # returned unchanged when there is nothing to validate
    try:
        with torch.no_grad():
            for batch in epoch_iterator_val:
                val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
                val_labels_convert = [post_label(t) for t in decollate_batch(val_labels)]
                val_output_convert = [post_pred(t) for t in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # running aggregate over every batch accumulated so far
                dice = dice_metric.aggregate().item()
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" % (global_step, 10.0, dice)
                )
    finally:
        # clear accumulated state even on error so the next round starts fresh
        dice_metric.reset()
    return dice
def _sliding_window_processor(_engine, batch):
    """Segment one batch with sliding-window inference and return discrete predictions.

    ``batch`` must unpack into exactly three items. Only the first (the image) is
    used; it is moved to the module-level ``device`` and passed through ``net``
    using the module-level ``roi_size`` and ``sw_batch_size``.

    Returns:
        The result of ``predict_segmentation`` on the windowed network output.
    """
    net.eval()
    image, _seg, _meta = batch
    with torch.no_grad():
        probs = sliding_window_inference(image.to(device), roi_size, sw_batch_size, net)
        return predict_segmentation(probs)
def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode, device):
    """Windowed inference of ``x + 1`` must equal ``x + 1`` over the full image.

    For ``mode == "constant"`` the input is an ``arange``, which gives every
    voxel a distinct value; otherwise it is all ones. Both the functional
    ``sliding_window_inference`` and ``SlidingWindowInferer`` are checked. Each
    check also verifies that the result stays on the requested device type.
    """
    size = np.prod(image_shape)
    constant_mode = mode == "constant"
    if constant_mode:
        inputs = torch.arange(size, dtype=torch.float).reshape(*image_shape)
    else:
        inputs = torch.ones(*image_shape, dtype=torch.float)
    # Downgrade CUDA cases to CPU on machines without a GPU.
    if device.type == "cuda" and not torch.cuda.is_available():
        device = torch.device("cpu:0")

    def compute(data):
        return data + 1

    if constant_mode:
        baseline = np.arange(size, dtype=np.float32).reshape(*image_shape)
    else:
        baseline = np.ones(image_shape, dtype=np.float32)
    expected_val = baseline + 1.0

    def check(result):
        np.testing.assert_string_equal(device.type, result.device.type)
        np.testing.assert_allclose(result.cpu().numpy(), expected_val)

    check(sliding_window_inference(inputs.to(device), roi_shape, sw_batch_size, compute, overlap, mode=mode))
    check(SlidingWindowInferer(roi_shape, sw_batch_size, overlap, mode)(inputs.to(device), compute))
def test_args_kwargs(self):
    """Extra positional and keyword arguments must be forwarded to the predictor.

    ``compute`` adds ``test1`` (passed positionally) and ``test2`` (passed by
    keyword). A 3x3 input of ones must therefore come back as all threes from
    both the functional API and ``SlidingWindowInferer``.
    """
    device = "cpu:0"
    if torch.cuda.is_available():
        device = "cuda"
    inputs = torch.ones((1, 1, 3, 3)).to(device=device)
    t1, t2 = torch.ones(1).to(device=device), torch.ones(1).to(device=device)
    roi_shape, sw_batch_size = (5, 5), 10

    def compute(data, test1, test2):
        return data + test1 + test2

    # The seven values after ``compute`` presumably fill overlap, mode,
    # sigma_scale, padding_mode, cval, sw_device and device (confirm against the
    # signature). They must stay positional so that ``t1`` lands in ``*args``.
    result = sliding_window_inference(
        inputs,
        roi_shape,
        sw_batch_size,
        compute,
        0.5,
        "constant",
        1.0,
        "constant",
        0.0,
        device,
        device,
        t1,
        test2=t2,
    )
    expected = np.ones((1, 1, 3, 3)) + 2.0
    np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)

    inferer = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1)
    result = inferer(inputs, compute, t1, test2=t2)
    np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)
def test_cval(self):
    """Padding must use ``cval``, and the padded voxels must reach the predictor.

    The 3x3 input of ones is smaller than the 5x5 window. ``compute`` adds the
    window sum to every element. Assuming the window is padded to 5x5 with
    ``cval=-1``, the sum is 9 * 1 + 16 * (-1) = -7, so each original voxel
    becomes 1 - 7 = -6.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu:0"
    inputs = torch.ones((1, 1, 3, 3)).to(device=device)
    roi_shape, sw_batch_size = (5, 5), 10

    def compute(data):
        return data + data.sum()

    expected = np.ones((1, 1, 3, 3)) * -6.0

    result = sliding_window_inference(
        inputs,
        roi_shape,
        sw_batch_size,
        compute,
        overlap=0.5,
        padding_mode="constant",
        cval=-1,
        mode="constant",
        sigma_scale=1.0,
    )
    np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)

    inferer = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1)
    result = inferer(inputs, compute)
    np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)
def evaluate(model_path, val_loader, plot_path=None):
    """Load weights into the module-level ``model`` and plot one slice per validation batch.

    Each batch's image is run through sliding-window inference (160^3 windows,
    4 per pass). A three-panel figure is then drawn: image, label, and the argmax
    of the output, at index ``[..., 80]`` of the first sample.

    Args:
        model_path: path to a state dict loadable with ``torch.load``.
        val_loader: iterable of dicts with ``'image'`` and ``'label'`` tensors;
            the plotting code indexes them as 5-D (``[0, 0, :, :, 80]``).
        plot_path: directory to save ``val_<i>_evaluation.png`` into. When
            ``None``, each figure is shown interactively instead.
    """
    roi_size = (160, 160, 160)
    sw_batch_size = 4
    slice_idx = 80
    # map_location lets a checkpoint saved on another device (e.g. GPU) load onto ``device``.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    with torch.no_grad():
        for i, val_data in enumerate(val_loader):
            val_outputs = sliding_window_inference(
                val_data['image'].to(device), roi_size, sw_batch_size, model)
            # plot the slice [:, :, 80]
            fig = plt.figure('check', (18, 6))
            # The named figure is reused across iterations; clear it so earlier
            # images are not stacked underneath the new ones.
            fig.clf()
            plt.subplot(1, 3, 1)
            plt.title(f"image {str(i)}")
            plt.imshow(val_data['image'][0, 0, :, :, slice_idx], cmap='gray')
            plt.subplot(1, 3, 2)
            plt.title(f"label {str(i)}")
            plt.imshow(val_data['label'][0, 0, :, :, slice_idx])
            plt.subplot(1, 3, 3)
            plt.title(f"output {str(i)}")
            plt.imshow(
                torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, slice_idx])
            if plot_path is not None:
                fig_path = os.path.join(plot_path, "val_{}_evaluation.png".format(i))
                plt.savefig(fig_path)
                print("Saved {}".format(fig_path))
                # Release the figure so memory does not grow with the number of cases.
                plt.close(fig)
            else:
                plt.show()
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Segment the NIfTI pairs in ``root_dir`` with a saved 3D UNet and return the mean Dice.

    Images ``im*.nii.gz`` and labels ``seg*.nii.gz`` are paired in sorted order.
    They are resampled, intensity-scaled and run through sliding-window inference
    using the checkpoint ``root_dir/best_metric_model.pth``. Predictions
    thresholded at sigmoid >= 0.5 are written to ``root_dir/output``.

    Returns:
        float: batch Dice values weighted by ``dice_metric.not_nans`` and averaged.
    """
    image_files = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    label_files = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [dict(img=image, seg=label) for image, label in zip(image_files, label_files)]

    # transforms shared by image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding-window inference consumes one volume per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
    model.eval()

    # window size and number of windows per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    with torch.no_grad():
        weighted_sum, valid_count = 0.0, 0
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32)
        for val_data in val_loader:
            images = val_data["img"].to(device)
            labels = val_data["seg"].to(device)
            logits = sliding_window_inference(images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=logits, y=labels)
            # weight this batch's Dice by the number of non-NaN entries
            n_valid = dice_metric.not_nans.item()
            valid_count += n_valid
            weighted_sum += value.item() * n_valid
            saver.save_batch((logits.sigmoid() >= 0.5).float(), val_data["img_meta_dict"])
    return weighted_sum / valid_count
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Segment the NIfTI pairs in ``root_dir`` with a saved 3D UNet and return the mean Dice.

    Images ``im*.nii.gz`` and labels ``seg*.nii.gz`` are paired in sorted order.
    They are intensity-scaled and run through sliding-window inference using
    ``root_dir/best_metric_model.pth``. Predictions thresholded at
    sigmoid >= 0.5 are saved to ``root_dir/output``.

    Returns:
        float: sum of all per-sample Dice values divided by their count.
    """
    image_files = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    label_files = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [dict(img=image, seg=label) for image, label in zip(image_files, label_files)]

    # transforms shared by image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding-window inference consumes one volume per iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
    model.eval()

    # window size and number of windows per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    with torch.no_grad():
        dice_total, n_scores = 0.0, 0
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=int)
        for val_data in val_loader:
            images = val_data["img"].to(device)
            labels = val_data["seg"].to(device)
            logits = sliding_window_inference(images, roi_size, sw_batch_size, model)
            scores = compute_meandice(
                y_pred=logits, y=labels, include_background=True, to_onehot_y=False, add_sigmoid=True
            )
            n_scores += len(scores)
            dice_total += scores.sum().item()
            masks = (logits.sigmoid() >= 0.5).float()
            meta = {
                "filename_or_obj": val_data["img.filename_or_obj"],
                "affine": val_data["img.affine"],
            }
            saver.save_batch(masks, meta)
    return dice_total / n_scores
def main(tempdir):
    """Generate 2D synthetic PNGs in ``tempdir``, evaluate a saved 2D UNet and print the mean Dice.

    Five 128x128 image/label pairs are written as ``img<i>.png`` / ``seg<i>.png``.
    They are run through sliding-window inference with
    ``best_metric_model_segmentation2d_dict.pth``. Predictions thresholded at
    sigmoid >= 0.5 are saved to ``./output``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        # ``im`` is presumably float intensities in [0, 1]. A bare astype("uint8")
        # truncates nearly every pixel to 0, so scale to 0..255 first;
        # ScaleIntensityd rescales it on load. ``seg`` presumably holds integer
        # labels and is cast unchanged.
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs 1 image per iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth"))
    model.eval()

    # window size and number of windows per forward pass
    roi_size = (96, 96)
    sw_batch_size = 4
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = PNGSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            val_outputs = val_outputs.sigmoid() >= 0.5
            saver.save_batch(val_outputs)
        metric = metric_sum / metric_count
    print("evaluation metric:", metric)
def run_inference_test(root_dir, device="cuda:0"):
    """Segment the NIfTI pairs in ``root_dir`` with a saved 3D UNet and return the mean Dice.

    Images ``im*.nii.gz`` and labels ``seg*.nii.gz`` are paired in sorted order.
    They are loaded, resampled and scaled, then run through sliding-window
    inference with ``root_dir/best_metric_model.pth``. Each decollated
    prediction (sigmoid, then threshold 0.5) is scored with ``DiceMetric``
    and saved to ``root_dir/output``.

    Returns:
        float: ``dice_metric.aggregate()`` over all batches.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [dict(img=img, seg=seg) for img, seg in zip(images, segs)]

    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding-window inference consumes one volume per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))

    # window size and number of windows per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = SaveImage(
            output_dir=os.path.join(root_dir, "output"),
            dtype=np.float32,
            output_ext=".nii.gz",
            output_postfix="seg",
            mode="bilinear",
        )
        for val_data in val_loader:
            batch_images = val_data["img"].to(device)
            batch_labels = val_data["seg"].to(device)
            logits = sliding_window_inference(batch_images, roi_size, sw_batch_size, model)
            # split the batch into per-sample predictions and metadata
            predictions = [val_post_tran(item) for item in decollate_batch(logits)]
            metas = decollate_batch(val_data[PostFix.meta("img")])
            dice_metric(y_pred=predictions, y=batch_labels)
            for prediction, meta in zip(predictions, metas):
                saver(prediction, meta)
        return dice_metric.aggregate().item()
def sliding_window_inference(self, image):
    """Run windowed inference on ``image`` using this object's configuration.

    Windows have size ``self.patch_size``. ``self.args.val_batch_size`` of them
    go through ``self.model`` per pass, overlapping by ``self.args.overlap`` and
    blended with mode ``self.args.blend``.
    """
    args = self.args
    return sliding_window_inference(
        inputs=image,
        roi_size=self.patch_size,
        sw_batch_size=args.val_batch_size,
        predictor=self.model,
        overlap=args.overlap,
        mode=args.blend,
    )
def _sliding_window_processor(_engine, batch):
    """Segment the image in ``batch`` with sliding-window inference.

    Only ``batch[0]`` is used (the comment in the original code identifies it as
    the image from ImageDataset). ``net`` is run in ``eval_mode`` and the
    stitched output is placed on the module-level ``device``.

    Returns:
        The result of ``predict_segmentation`` on the network output.
    """
    image = batch[0]
    with eval_mode(net):
        probs = sliding_window_inference(image.to(device), roi_size, sw_batch_size, net, device=device)
        return predict_segmentation(probs)
def _sliding_window_processor(_engine, batch):
    """Segment the image in a three-item ``batch`` with sliding-window inference.

    ``batch`` must unpack into exactly three items; only the first (the image)
    is used. ``net`` runs in ``eval_mode`` and the output is placed on the
    module-level ``device``.

    Returns:
        The result of ``predict_segmentation`` on the network output.
    """
    image, _seg, _meta = batch
    with eval_mode(net):
        probs = sliding_window_inference(image.to(device), roi_size, sw_batch_size, net, device=device)
        return predict_segmentation(probs)
def _sliding_window_processor(engine, batch):
    """Run windowed inference on a dict batch, save each prediction and return it.

    Args:
        engine: not used by this function.
        batch: dict with ``"img"``, ``"seg"`` and ``"img_meta_dict"`` entries.

    Returns:
        tuple: ``(seg_probs, labels)``. ``seg_probs`` is a list of per-sample
        predictions after ``post_trans``; ``labels`` is the batched label tensor
        on ``device``.
    """
    net.eval()
    with torch.no_grad():
        images = batch["img"].to(device)
        labels = batch["seg"].to(device)
        raw = sliding_window_inference(images, roi_size, sw_batch_size, net)
        probs = [post_trans(sample) for sample in decollate_batch(raw)]
        metas = decollate_batch(batch["img_meta_dict"])
        # write every post-processed sample with its own metadata
        for prob, meta in zip(probs, metas):
            save_image(prob, meta)
        return probs, labels
def test_step(self, batch, batch_idx):
    """Score one test batch with ``self.val_metric`` after sliding-window inference.

    Window size and window batch size come from ``self.userinputs``.

    Returns:
        dict: ``{"test_dice": <metric value>}``.
    """
    images = batch["IMAGE"]
    labels = batch["SEGM"]
    settings = self.userinputs
    outputs = sliding_window_inference(images, settings['roi_size'], settings['sw_batch_size'], self.forward)
    # TODO: handle test images that have no ground-truth mask
    return {"test_dice": self.val_metric(y_pred=outputs, y=labels)}
def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode):
    """On CPU, windowed inference of ``x + 1`` over an all-ones input must give all twos."""
    inputs = torch.ones(*image_shape)
    cpu = torch.device("cpu:0")

    def compute(data):
        return data + 1

    result = sliding_window_inference(inputs.to(cpu), roi_shape, sw_batch_size, compute, overlap, mode=mode)
    expected_val = np.ones(image_shape, dtype=np.float32) + 1
    self.assertTrue(np.allclose(result.numpy(), expected_val))
def main(tempdir):
    """Generate 3D synthetic NIfTI data in ``tempdir``, evaluate a saved 3D UNet and print the mean Dice.

    Five 128^3 image/label volumes are written with an identity affine. They are
    read back through ``ImageDataset`` and segmented with sliding-window
    inference using ``best_metric_model_segmentation3d_array.pth``.
    Post-processed predictions are saved to ``./output``.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        nib.save(nib.Nifti1Image(im, np.eye(4)), os.path.join(tempdir, f"im{i:d}.nii.gz"))
        nib.save(nib.Nifti1Image(seg, np.eye(4)), os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # images are intensity-scaled; segmentations only get a channel axis
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding-window inference consumes one volume per iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth"))
    model.eval()

    # window size and number of windows per forward pass
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    with torch.no_grad():
        running_sum, running_count = 0.0, 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            batch_images = val_data[0].to(device)
            batch_labels = val_data[1].to(device)
            preds = post_trans(sliding_window_inference(batch_images, roi_size, sw_batch_size, model))
            value, _ = dice_metric(y_pred=preds, y=batch_labels)
            running_count += len(value)
            running_sum += value.item() * len(value)
            saver.save_batch(preds, val_data[2])
        metric = running_sum / running_count
    print("evaluation metric:", metric)
def test_default_device(self):
    """Without an explicit device, the output must stay on the input's device type."""
    device = "cpu:0"
    if torch.cuda.is_available():
        device = "cuda"
    shape = (1, 3, 16, 15, 7)
    inputs = torch.ones(shape).to(device=device)

    def compute(data):
        return data + 1

    result = sliding_window_inference(inputs, (4, 10, 7), 10, compute)
    np.testing.assert_string_equal(inputs.device.type, result.device.type)
    np.testing.assert_allclose(result.cpu().numpy(), np.ones(shape, dtype=np.float32) + 1)
def test_sw_device(self):
    """Windows must run on ``sw_device="cuda"`` while the output returns to the input's device.

    The predictor asserts it receives CUDA tensors and builds a CUDA tensor
    itself. The test therefore cannot pass on a CPU-only machine and is skipped
    there instead of failing.
    """
    if not torch.cuda.is_available():
        self.skipTest("test_sw_device requires CUDA")
    inputs = torch.ones((1, 3, 16, 15, 7)).to(device="cpu")
    roi_shape = (4, 10, 7)
    sw_batch_size = 10

    def compute(data):
        self.assertEqual(data.device.type, "cuda")
        return data + torch.tensor(1, device="cuda")

    result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, sw_device="cuda")
    np.testing.assert_string_equal(inputs.device.type, result.device.type)
    expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1
    np.testing.assert_allclose(result.cpu().numpy(), expected_val)
def predict(self, X, model):  # noqa: N803
    """Run sliding-window inference over ``X`` and collect post-processed outputs.

    Args:
        X: iterable of ``(inputs, metadata)`` pairs.
        model: a two-item pair whose first element is the network; the second
            (the optimizer) is discarded.

    Returns:
        list: ``(post_trans(outputs), metadata)`` for every item in ``X``.
    """
    net, _optimizer = model
    net.eval()
    roi_size, sw_batch_size = (96, 96, 96), 4
    predictions = []
    with torch.no_grad():
        for inputs, metadata in X:
            batch = inputs.to(self.device)
            outputs = post_trans(sliding_window_inference(batch, roi_size, sw_batch_size, net))
            predictions.append((outputs, metadata))
    return predictions
def forward(self, batch):
    """Produce the model output dict for the current loader.

    - train loader: plain forward pass.
    - valid loader: sliding-window inference (96^3 windows, 4 per pass).
    - infer loader: same windowed inference after moving ``batch`` to
      ``self.device``. The batch entries are merged into the returned dict.

    NOTE(review): the merge is reconstructed as belonging to the infer branch
    only; confirm against the original indentation. If no loader flag is set,
    ``output`` is never bound and an ``UnboundLocalError`` is raised.
    """
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    if self.is_train_loader:
        output = {self.output_key: self.model(batch[self.input_key])}
    elif self.is_valid_loader:
        prediction = sliding_window_inference(batch[self.input_key], roi_size, sw_batch_size, self.model)
        output = {self.output_key: prediction}
    elif self.is_infer_loader:
        batch = self._batch2device(batch, self.device)
        prediction = sliding_window_inference(batch[self.input_key], roi_size, sw_batch_size, self.model)
        output = {**{self.output_key: prediction}, **batch}
    return output
def validation_step(self, batch, batch_idx):
    """Compute loss and mean Dice for one validation batch.

    Windowed inference uses ``PATCH_SIZE`` windows, one per pass. The loss uses
    the raw outputs; Dice (background excluded) uses the ``post_pred`` /
    ``post_label`` transformed tensors.

    Returns:
        dict: ``{"val_loss": loss, "val_dice": dice}``.
    """
    images = batch["image"]
    labels = batch["label"]
    outputs = sliding_window_inference(images, PATCH_SIZE, 1, self.forward)
    loss = self.loss_function(outputs, labels)
    dice = compute_meandice(y_pred=self.post_pred(outputs), y=self.post_label(labels), include_background=False)
    return {"val_loss": loss, "val_dice": dice}
def main(tempdir):
    """Generate 2D synthetic PNGs in ``tempdir``, evaluate a saved 2D UNet and print the mean Dice.

    Five 128x128 image/label pairs are written (scaled by 255 to uint8). They are
    loaded through ``ArrayDataset`` and segmented with sliding-window inference
    using ``best_metric_model_segmentation2d_array.pth``. Each post-processed
    prediction is saved as PNG under ``./output``.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # identical loading pipeline for images and segmentations
    imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_ds = ArrayDataset(images, imtrans, segs, segtrans)
    # sliding-window inference consumes one image per iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth"))
    model.eval()

    # window size and number of windows per forward pass
    roi_size = (96, 96)
    sw_batch_size = 4
    with torch.no_grad():
        for val_data in val_loader:
            batch_images = val_data[0].to(device)
            batch_labels = val_data[1].to(device)
            logits = sliding_window_inference(batch_images, roi_size, sw_batch_size, model)
            predictions = [post_trans(item) for item in decollate_batch(logits)]
            labels = decollate_batch(batch_labels)
            # accumulate this iteration into the running metric
            dice_metric(y_pred=predictions, y=labels)
            for prediction in predictions:
                saver(prediction)
        # dataset-wide aggregate, then clear the metric state
        print("evaluation metric:", dice_metric.aggregate().item())
        dice_metric.reset()
def main():
    """Run CopleNet inference on every ``case*.nii.gz`` in ``IMAGE_FOLDER``.

    Each volume is loaded, given a channel axis, reoriented to "SPL" and
    segmented with sliding-window inference (circular padding, no overlap).
    The channel-wise argmax is saved to ``OUTPUT_FOLDER``.
    """
    images = sorted(glob(os.path.join(IMAGE_FOLDER, "case*.nii.gz")))
    val_files = [dict(img=path) for path in images]
    infer_transforms = Compose(
        [
            LoadNiftid("img"),
            AddChanneld("img"),
            # CopleNet works on the plane defined by the last two axes
            Orientationd("img", "SPL"),
            ToTensord("img"),
        ]
    )
    test_ds = monai.data.Dataset(data=val_files, transform=infer_transforms)
    # sliding-window inference consumes one volume per iteration
    data_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available()
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CopleNet().to(device)
    checkpoint = torch.load(MODEL_FILE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    sw_batch_size = 2
    with torch.no_grad():
        saver = NiftiSaver(output_dir=OUTPUT_FOLDER)
        for idx, val_data in enumerate(data_loader):
            print(f"Inference on {idx+1} of {len(data_loader)}")
            val_images = val_data["img"].to(device)
            # round the last two sizes up to a multiple of 32; the window covers the whole rounded plane
            in_plane = np.ceil(np.asarray(val_images.shape[3:]) / 32) * 32
            roi_size = (20, int(in_plane[0]), int(in_plane[1]))
            val_outputs = sliding_window_inference(
                val_images, roi_size, sw_batch_size, model, 0.0, padding_mode="circular"
            )
            # label map = argmax over the channel axis
            labels = val_outputs.argmax(dim=1, keepdim=True)
            saver.save_batch(labels, val_data["img_meta_dict"])
def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode, device):
    """Windowed inference of ``x + 1`` on an all-ones input must give all twos on ``device``.

    CUDA cases fall back to ``cpu:0`` when no GPU is available. The result is
    copied to the CPU before comparison, because ``Tensor.numpy()`` rejects
    CUDA tensors.
    """
    inputs = torch.ones(*image_shape)
    if device.type == "cuda" and not torch.cuda.is_available():
        device = torch.device("cpu:0")

    def compute(data):
        return data + 1

    result = sliding_window_inference(inputs.to(device), roi_shape, sw_batch_size, compute, overlap, mode=mode)
    np.testing.assert_string_equal(device.type, result.device.type)
    expected_val = np.ones(image_shape, dtype=np.float32) + 1
    np.testing.assert_allclose(result.cpu().numpy(), expected_val)
def plot_dice(images_loader):
    """Compute the mean Dice over ``images_loader`` and record it in ``metric_values``.

    Uses the module-level ``net``, ``opt.patch_size``, ``post_trans`` and
    ``dice_metric``. The ``dice_metric`` call is unpacked as ``(value, _)``.
    Each batch's value is weighted by its length.

    Returns:
        tuple: ``(metric, val_images, val_labels, val_outputs)``. The last three
        come from the final batch processed.
    """
    total, count = 0.0, 0
    val_images = val_labels = val_outputs = None
    for data in images_loader:
        val_images = data["image"].cuda()
        val_labels = data["label"].cuda()
        raw = sliding_window_inference(val_images, opt.patch_size, 4, net)
        val_outputs = post_trans(raw)
        value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
        count += len(value)
        total += value.item() * len(value)
    metric = total / count
    metric_values.append(metric)
    return metric, val_images, val_labels, val_outputs
def plot_dice(images_loader):
    """Run sliding-window inference over ``images_loader`` and return the aggregated Dice.

    Uses the module-level ``net``, ``opt.patch_size``, ``post_trans`` and
    ``dice_metric``. The metric is reset before returning.

    Returns:
        tuple: ``(metric, val_images, val_labels, val_outputs)``. The last three
        come from the final batch; ``val_outputs`` is a list of per-sample
        post-processed predictions.
    """
    val_images = val_labels = val_outputs = None
    for data in images_loader:
        val_images = data["image"].cuda()
        val_labels = data["label"].cuda()
        raw = sliding_window_inference(val_images, opt.patch_size, 4, net)
        val_outputs = [post_trans(sample) for sample in decollate_batch(raw)]
        dice_metric(y_pred=val_outputs, y=val_labels)
    # final mean Dice across every batch, then clear state for the next round
    metric = dice_metric.aggregate().item()
    dice_metric.reset()
    return metric, val_images, val_labels, val_outputs
def predict_mask(model, images):
    """Predict a uint8 label mask for a stack of slices.

    Args:
        model: predictor passed to ``sliding_window_inference``.
        images: iterable of objects exposing ``pixel_array``, one per slice.

    Returns:
        np.ndarray: C-contiguous uint8 array. It is the argmax over dim 1 of the
        windowed output (128x128x16 windows), with the leading singleton axis
        squeezed and the axes transposed.
    """
    import torch

    volume = np.array([image.pixel_array for image in images]).astype('float32')
    # reverse the axes and add two leading singleton axes
    image_array = np.expand_dims(np.transpose(volume), (0, 1))
    data_transforms = Compose([AddChannel(), NormalizeIntensity(), ToTensor()])
    dataset = monai.data.Dataset(data=image_array, transform=data_transforms)
    # Apply the transform chain once; indexing dataset[0] twice would run it twice.
    input_tensor = dataset[0]
    print(input_tensor.shape)
    # Inference only: avoid building an autograd graph for the whole volume.
    with torch.no_grad():
        logits = sliding_window_inference(input_tensor, roi_size=[128, 128, 16], sw_batch_size=1, predictor=model)
    test_mask = logits.argmax(1).cpu().numpy()
    test_mask = np.transpose(np.squeeze(test_mask, 0)).astype('uint8')
    return np.asarray(test_mask, order='C')
def evaluate(model, val_loader, dice_metric, dice_metric_batch, post_trans):
    """Evaluate ``model`` on ``val_loader`` and return overall and per-channel Dice.

    Inference uses 240x240x160 windows, 4 per pass, 0.6 overlap, under CUDA
    autocast. Outputs are decollated and passed through ``post_trans`` before
    scoring. Both metrics are reset before returning.

    Returns:
        tuple: ``(metric, metric_tc, metric_wt, metric_et)``. The last three are
        channels 0, 1 and 2 of ``dice_metric_batch.aggregate()``.
    """
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            with torch.cuda.amp.autocast():
                logits = sliding_window_inference(
                    inputs=batch["image"],
                    roi_size=(240, 240, 160),
                    sw_batch_size=4,
                    predictor=model,
                    overlap=0.6,
                )
            predictions = [post_trans(item) for item in decollate_batch(logits)]
            dice_metric(y_pred=predictions, y=batch["label"])
            dice_metric_batch(y_pred=predictions, y=batch["label"])
        metric = dice_metric.aggregate().item()
        per_channel = dice_metric_batch.aggregate()
        metric_tc, metric_wt, metric_et = (per_channel[idx].item() for idx in (0, 1, 2))
        dice_metric.reset()
        dice_metric_batch.reset()
    return metric, metric_tc, metric_wt, metric_et