Exemple #1
0
    def test_saved_3d_resize_content(self):
        """Save an 8-item batch with resampling and check one output file per item.

        Every item is expected at ``<tempdir>/testfileN/testfileN_seg.nii.gz``.
        """
        n_items = 8
        with tempfile.TemporaryDirectory() as tempdir:

            saver = NiftiSaver(
                output_dir=tempdir,
                output_postfix="seg",
                output_ext=".nii.gz",
                dtype=np.float32,
            )

            stems = [f"testfile{idx}" for idx in range(n_items)]
            meta_data = {
                "filename_or_obj": [stem + ".nii.gz" for stem in stems],
                "spatial_shape": [(10, 10, 2)] * n_items,
                "affine": [np.diag(np.ones(4)) * 5] * n_items,
                "original_affine": [np.diag(np.ones(4)) * 1.0] * n_items,
            }
            batch = torch.randint(0, 255, (n_items, 8, 1, 2, 2))
            saver.save_batch(batch, meta_data)
            for stem in stems:
                expected_path = os.path.join(tempdir, stem, stem + "_seg.nii.gz")
                self.assertTrue(os.path.exists(expected_path))
Exemple #2
0
def main():
    """Run CopleNet inference on every ``case*.nii.gz`` volume in IMAGE_FOLDER.

    Each volume is loaded, given a channel dimension, reoriented to "SPL" and
    segmented with sliding-window inference. The per-voxel argmax over the
    model's output channels is written to OUTPUT_FOLDER with NiftiSaver.
    """
    images = sorted(glob(os.path.join(IMAGE_FOLDER, "case*.nii.gz")))
    val_files = [{"img": img} for img in images]

    # define transforms for the image (no segmentation is loaded at inference time)
    infer_transforms = Compose(
        [
            LoadNiftid("img"),
            AddChanneld("img"),
            Orientationd("img", "SPL"),  # coplenet works on the plane defined by the last two axes
            ToTensord("img"),
        ]
    )
    test_ds = monai.data.Dataset(data=val_files, transform=infer_transforms)
    # sliding window inference need to input 1 image in every iteration
    data_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available()
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CopleNet().to(device)

    # map_location keeps a GPU-saved checkpoint loadable when `device` fell back to CPU
    checkpoint = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    with torch.no_grad():
        saver = NiftiSaver(output_dir=OUTPUT_FOLDER)
        for idx, val_data in enumerate(data_loader):
            print(f"Inference on {idx+1} of {len(data_loader)}")
            val_images = val_data["img"].to(device)
            # window size: 20 along the first spatial axis; the axes from shape[3:]
            # (assumed to be the last two spatial axes) rounded up to a multiple of 32
            slice_shape = np.ceil(np.asarray(val_images.shape[3:]) / 32) * 32
            roi_size = (20, int(slice_shape[0]), int(slice_shape[1]))
            sw_batch_size = 2
            val_outputs = sliding_window_inference(
                val_images, roi_size, sw_batch_size, model, 0.0, padding_mode="circular"
            )
            # label map: index of the highest-scoring output channel per voxel
            val_outputs = val_outputs.argmax(dim=1, keepdim=True)
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
    def test_saved_3d_no_resize_content(self):
        """Save an 8-item batch without resampling and check each reloaded shape.

        With ``resample=False`` every saved file should load back as (1, 2, 2, 8).
        """
        with tempfile.TemporaryDirectory() as tempdir:

            saver = NiftiSaver(
                output_dir=tempdir,
                output_postfix="seg",
                output_ext=".nii.gz",
                dtype=np.float32,
                resample=False,
            )

            count = 8
            meta_data = {
                "filename_or_obj": [f"testfile{k}.nii.gz" for k in range(count)],
                "spatial_shape": [(10, 10, 2)] * count,
                "affine": [np.diag(np.ones(4)) * 5] * count,
                "original_affine": [np.diag(np.ones(4)) * 1.0] * count,
            }
            saver.save_batch(torch.randint(0, 255, (count, 8, 1, 2, 2)), meta_data)
            for k in range(count):
                stem = f"testfile{k}"
                saved_path = os.path.join(tempdir, stem, f"{stem}_seg.nii.gz")
                img, _ = LoadImage("nibabelreader")(saved_path)
                self.assertEqual(img.shape, (1, 2, 2, 8))
Exemple #4
0
    def test_saved_3d_resize_content(self):
        """Save an 8-item batch with resampling into ``./tempdir`` and check the files exist.

        Each item is expected at ``./tempdir/testfileN/testfileN_seg.nii.gz``.
        The directory is removed in a ``finally`` clause, so a failing assertion
        does not leave it behind for later runs.
        """
        default_dir = os.path.join(".", "tempdir")
        # clear leftovers from a previous run before writing new outputs
        shutil.rmtree(default_dir, ignore_errors=True)
        try:
            saver = NiftiSaver(output_dir=default_dir,
                               output_postfix="seg",
                               output_ext=".nii.gz",
                               dtype=np.float32)

            meta_data = {
                "filename_or_obj": ["testfile" + str(i) for i in range(8)],
                "spatial_shape": [(10, 10, 2)] * 8,
                "affine": [np.diag(np.ones(4)) * 5] * 8,
                "original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
            }
            saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data)
            for i in range(8):
                filepath = os.path.join("testfile" + str(i),
                                        "testfile" + str(i) + "_seg.nii.gz")
                self.assertTrue(os.path.exists(os.path.join(default_dir,
                                                            filepath)))
        finally:
            # always clean up; ignore_errors avoids masking an earlier failure
            # if the saver never created the directory
            shutil.rmtree(default_dir, ignore_errors=True)
def main():
    """Evaluate a pretrained 3D UNet on synthetic volumes and save the predictions.

    Five random 128^3 image/segmentation pairs are written to a temporary
    directory. The model in ``best_metric_model.pth`` is run with sliding-window
    inference, the mean Dice over all evaluated items is printed, and the
    thresholded predictions are saved under ``./output``. The temporary
    directory is removed afterwards, also when evaluation raises.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        val_ds = NiftiDataset(images,
                              segs,
                              transform=imtrans,
                              seg_transform=segtrans,
                              image_only=False)
        # sliding window inference for one image at every iteration
        val_loader = DataLoader(val_ds,
                                batch_size=1,
                                num_workers=1,
                                pin_memory=torch.cuda.is_available())

        # fall back to CPU, consistent with the pin_memory check above
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        # map_location lets a GPU-saved checkpoint load on the chosen device
        model.load_state_dict(torch.load("best_metric_model.pth", map_location=device))
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                # with image_only=False the third element is passed to the saver as metadata
                val_images, val_labels = val_data[0].to(device), val_data[1].to(
                    device)
                # define sliding window size and batch size for windows inference
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, model)
                value = compute_meandice(y_pred=val_outputs,
                                         y=val_labels,
                                         include_background=True,
                                         to_onehot_y=False,
                                         add_sigmoid=True)
                # accumulate per-item scores so the final metric is a mean over items
                metric_count += len(value)
                metric_sum += value.sum().item()
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(val_outputs, val_data[2])
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
    finally:
        shutil.rmtree(tempdir)
Exemple #6
0
def main():
    """Evaluate a pretrained 3D UNet on synthetic channel-last volumes.

    Five 128^3 image/segmentation pairs are generated into a temporary
    directory and loaded through dictionary transforms. Each image is segmented
    with sliding-window inference using ``best_metric_model.pth`` and scored
    with mean Dice, and the thresholded predictions are saved under
    ``./output``. The temporary directory is removed afterwards, also when
    evaluation raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(
            tempdir))
        for i in range(5):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
        val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
            # only the image is intensity-scaled; label values must stay unchanged
            ScaleIntensityd(keys=['img']),
            ToTensord(keys=['img', 'seg'])
        ])
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        # sliding window inference need to input 1 image in every iteration
        val_loader = DataLoader(val_ds,
                                batch_size=1,
                                num_workers=4,
                                collate_fn=list_data_collate,
                                pin_memory=torch.cuda.is_available())

        # fall back to CPU, consistent with the pin_memory check above
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        # map_location lets a GPU-saved checkpoint load on the chosen device
        model.load_state_dict(torch.load('best_metric_model.pth', map_location=device))
        model.eval()
        with torch.no_grad():
            metric_sum = 0.
            metric_count = 0
            saver = NiftiSaver(output_dir='./output')
            for val_data in val_loader:
                val_images, val_labels = val_data['img'].to(
                    device), val_data['seg'].to(device)
                # define sliding window size and batch size for windows inference
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, model)
                value = compute_meandice(y_pred=val_outputs,
                                         y=val_labels,
                                         include_background=True,
                                         to_onehot_y=False,
                                         add_sigmoid=True)
                # accumulate per-item scores so the final metric is a mean over items
                metric_count += len(value)
                metric_sum += value.sum().item()
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(
                    val_outputs, {
                        'filename_or_obj': val_data['img.filename_or_obj'],
                        'affine': val_data['img.affine']
                    })
            metric = metric_sum / metric_count
            print('evaluation metric:', metric)
    finally:
        shutil.rmtree(tempdir)
def main(tempdir):
    """Evaluate a pretrained 3D UNet on synthetic data written to ``tempdir``.

    Five 128^3 image/segmentation pairs are generated and segmented with
    sliding-window inference. The model is wrapped in DataParallel when more
    than one device is available. Predictions are scored with DiceMetric and
    saved under ``./output``.

    Args:
        tempdir: existing directory that receives the synthetic ``im*.nii.gz``
            and ``seg*.nii.gz`` files; it is not removed here.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    # try to use all the available GPUs
    devices = get_devices_spec(None)
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(devices[0])

    # map_location places the checkpoint tensors on the primary device
    model.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth", map_location=devices[0])
    )

    # if we have multiple GPUs, set data parallel to execute sliding window inference
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices)

    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            # NOTE(review): assumes the MONAI DiceMetric API in which reduction="mean"
            # returns (mean score over non-NaN items, number of such items). The mean
            # is a 0-d tensor with no len(), so weight it by the returned count.
            value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
            not_nans = not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        # avoid ZeroDivisionError if no item produced a valid (non-NaN) score
        metric = metric_sum / metric_count if metric_count else float("nan")
        print("evaluation metric:", metric)
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# Evaluate `model` on `val_loader`, print the mean Dice and save thresholded
# predictions under ./output. `model`, `device`, `val_loader` and `tempdir` are
# defined earlier in the script (not visible in this fragment).
model.load_state_dict(torch.load('best_metric_model.pth'))
model.eval()
with torch.no_grad():
    metric_sum = 0.
    metric_count = 0
    saver = NiftiSaver(output_dir='./output')
    for val_data in val_loader:
        # val_data[0] is used as the image batch, val_data[1] as labels, and
        # val_data[2] is passed to the saver as metadata
        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_images, roi_size,
                                               sw_batch_size, model)
        value = compute_meandice(y_pred=val_outputs,
                                 y=val_labels,
                                 include_background=True,
                                 to_onehot_y=False,
                                 add_sigmoid=True)
        # presumably `value` holds one Dice score per batch item; the running
        # sum/count gives a mean over all evaluated items
        metric_count += len(value)
        metric_sum += value.sum().item()
        # binarize the logits at probability 0.5 before saving
        val_outputs = (val_outputs.sigmoid() >= 0.5).float()
        saver.save_batch(val_outputs, val_data[2])
    metric = metric_sum / metric_count
    print('evaluation metric:', metric)
# NOTE(review): tempdir is only removed when evaluation finishes without raising
shutil.rmtree(tempdir)
Exemple #9
0
def main():
    """Run 2D-UNet inference on the subjects listed in a YAML config file.

    Reads ``--config`` and builds the inference file list with
    ``create_data_list``. Each volume is resized in-plane to 96x96 and run
    through sliding-window inference. The sigmoid output is thresholded at the
    configured probability and resized back to the original size. The
    prediction is masked by its post-processed (closed, largest-component)
    version and saved with NiftiSaver.

    Raises:
        IOError: if the model checkpoint named in the config does not exist.
        ValueError: if the configured inference batch size is not 1, or if no
            inference files were found.
    """
    parser = argparse.ArgumentParser(
        description='Run inference with basic UNet with MONAI.')
    parser.add_argument('--config',
                        dest='config',
                        metavar='config',
                        type=str,
                        help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # inference params
    batch_size_inference = config_info['inference']['batch_size_inference']
    # sliding window inference here does not accept a higher batch size; an explicit
    # raise (unlike `assert`) is not stripped when running under `python -O`
    if batch_size_inference != 1:
        raise ValueError(
            f'batch_size_inference must be 1, got {batch_size_inference}')
    prob_thr = config_info['inference']['probability_threshold']
    model_to_load = config_info['inference']['model_to_load']
    if not os.path.exists(model_to_load):
        raise IOError('Trained model not found')
    # data params
    data_root = config_info['data']['data_root']
    inference_list = config_info['data']['inference_list']
    # output saving
    out_dir = config_info['output']['out_dir']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)
    """
    Data Preparation
    """
    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=inference_list,
                                 img_postfix='_Image',
                                 is_inference=True)
    # fail with a clear message instead of an IndexError on val_files[0]
    if not val_files:
        raise ValueError('No inference files found for the configured data')

    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - NOTE: resizing needs to be applied afterwards, otherwise it cannot be remapped back to original size
    val_transforms = Compose([
        LoadNiftid(keys=['img']),
        AddChanneld(keys=['img']),
        NormalizeIntensityd(keys=['img']),
        ToTensord(keys=['img'])
    ])
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_inference,
                            num_workers=num_workers)
    """
    Network preparation
    """
    device = torch.cuda.current_device()
    # Create the UNet (inference only: no loss or optimizer is needed)
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    net.load_state_dict(torch.load(model_to_load))
    net.eval()
    """
    Run inference
    """
    # structuring element for the post-processing closing; identical for every volume
    strt = ndimage.generate_binary_structure(3, 2)
    with torch.no_grad():
        saver = NiftiSaver(output_dir=out_dir)
        for val_data in val_loader:
            val_images = val_data['img'].to(device)
            orig_size = list(val_images.shape)
            # resize the first two spatial axes to 96x96 and keep the third unchanged
            resized_size = list(orig_size)
            resized_size[2] = 96
            resized_size[3] = 96
            val_images_resize = torch.nn.functional.interpolate(
                val_images, size=resized_size[2:], mode='trilinear')
            # define sliding window size and batch size for windows inference
            # NOTE(review): 96x96x1 windows are fed to a UNet built with dimensions=2;
            # confirm the sliding_window_inference in use handles the singleton axis.
            roi_size = (96, 96, 1)
            val_outputs = sliding_window_inference(val_images_resize, roi_size,
                                                   batch_size_inference, net)
            val_outputs = (val_outputs.sigmoid() >= prob_thr).float()
            # nearest-neighbour keeps the binary mask binary when mapping back
            val_outputs_resized = torch.nn.functional.interpolate(
                val_outputs, size=orig_size[2:], mode='nearest')
            # add post-processing
            val_outputs_resized = val_outputs_resized.detach().cpu().numpy()
            post = padded_binary_closing(np.squeeze(val_outputs_resized), strt)
            post = get_largest_component(post)
            val_outputs_resized = val_outputs_resized * post

            saver.save_batch(
                val_outputs_resized, {
                    'filename_or_obj': val_data['img.filename_or_obj'],
                    'affine': val_data['img.affine']
                })
Exemple #10
0
# Evaluate `model` on a dictionary-based `val_loader`, print the mean Dice and
# save thresholded predictions under ./output. `model`, `device`, `val_loader`
# and `tempdir` are defined earlier in the script (not visible in this fragment).
model.load_state_dict(torch.load('best_metric_model.pth'))
model.eval()
with torch.no_grad():
    metric_sum = 0.
    metric_count = 0
    saver = NiftiSaver(output_dir='./output')
    for val_data in val_loader:
        val_images, val_labels = val_data['img'].to(
            device), val_data['seg'].to(device)
        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_images, roi_size,
                                               sw_batch_size, model)
        value = compute_meandice(y_pred=val_outputs,
                                 y=val_labels,
                                 include_background=True,
                                 to_onehot_y=False,
                                 add_sigmoid=True)
        # presumably `value` holds one Dice score per batch item; the running
        # sum/count gives a mean over all evaluated items
        metric_count += len(value)
        metric_sum += value.sum().item()
        # binarize the logits at probability 0.5 before saving
        val_outputs = (val_outputs.sigmoid() >= 0.5).float()
        # the saver receives only the image filename and affine from the batch
        saver.save_batch(
            val_outputs, {
                'filename_or_obj': val_data['img.filename_or_obj'],
                'affine': val_data['img.affine']
            })
    metric = metric_sum / metric_count
    print('evaluation metric:', metric)
# NOTE(review): tempdir is only removed when evaluation finishes without raising
shutil.rmtree(tempdir)
Exemple #11
0
        logits = forward(model, inputs)
    labels = logits.argmax(dim=CHANNELS_DIMENSION, keepdim=True)
    batch_mri = inputs
    batch_label = labels
    #slices = torch.cat((batch_mri, batch_label))
    #slices = torch.cat((batch_mri, batch_label),dim=1)
    #inf_path = 'inference.png'
    #save_image(slices, inf_path, nrow=training_batch_size//2, normalize=True, scale_each=True, padding=0)
    #display.Image(inf_path)

    #saver = NiftiSaver(output_dir="./niftinferece",output_postfix = str(i))
    #saver.save_batch(slices)

    saver1 = NiftiSaver(output_dir="./inputsnifti", output_postfix=str(i))
    saver2 = NiftiSaver(output_dir="./labelsnifti", output_postfix=str(i))
    saver1.save_batch(inputs)
    saver2.save_batch(labels)

    #Dice score for inference slide
    dice_score.append(
        get_dice_score(F.softmax(logits, dim=CHANNELS_DIMENSION), targets))
    #Dice loss for inference slide
    dice_losses.append(
        get_dice_loss(F.softmax(logits, dim=CHANNELS_DIMENSION), targets))

## experimental to output all dice scores -- WORKS!
dataset_iter = iter(validation_loader2)
for i in range(5):
    try:
        inputs, targets = prepare_batch(
            next(dataset_iter), 16,
Exemple #12
0
def run_inference_test(root_dir, device="cuda:0"):
    """Run the trained 3D UNet on the volumes in ``root_dir`` and return the mean Dice.

    Pairs ``im*.nii.gz`` with ``seg*.nii.gz`` by sorted order and resamples both
    to a fixed pixel spacing. Each image is segmented with sliding-window
    inference using ``root_dir/best_metric_model.pth``. The post-processed
    predictions are saved under ``root_dir/output``, and DiceMetric is
    aggregated over all items.

    Args:
        root_dir: directory holding the images, labels and model checkpoint.
        device: torch device (or device string) that the model and data are moved to.

    Returns:
        The aggregated Dice score as a Python float.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        EnsureChannelFirstd(keys=["img", "seg"]),
        # resampling with align_corners=True or dtype=float64 will generate
        # slight different results between PyTorch 1.5 an 1.6
        Spacingd(keys=["img", "seg"],
                 pixdim=[1.2, 0.8, 0.7],
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([
        ToTensor(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True)
    ])
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # map_location lets a GPU-saved checkpoint load when `device` is e.g. "cpu"
    model.load_state_dict(torch.load(model_filename, map_location=device))
    with eval_mode(model):
        # the saver writes float32; as with Spacingd above, float64 resampling
        # can give slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=np.float32)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            # decollate prediction into a list and execute post processing for every item
            val_outputs = [
                val_post_tran(i) for i in decollate_batch(val_outputs)
            ]
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            saver.save_batch(val_outputs, val_data["img_meta_dict"])

    return dice_metric.aggregate().item()