Exemplo n.º 1
0
    def test_shape(self, spatial_dim):
        """Check that ``SliceInferer`` preserves the input shape, also when reused.

        A small 2D ``UNet`` is applied through a ``SliceInferer`` to a dummy
        volume of shape (1, 1, 64, 256, 256). The output must have the same
        shape as the input on the first call and on a second call made with the
        same inferer instance.

        Args:
            spatial_dim: index, within the volume's spatial axes (the axes
                after N and C), of the axis to slide across. It is converted
                with ``int`` before use.
        """
        spatial_dim = int(spatial_dim)

        model = UNet(
            spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2), num_res_units=2
        )

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()

        # Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)
        input_volume = torch.ones(1, 1, 64, 256, 256, device=device)

        # Remove spatial dim to slide across from the roi_size
        roi_size = list(input_volume.shape[2:])
        roi_size.pop(spatial_dim)

        # Initialize and run inferer
        inferer = SliceInferer(roi_size=roi_size, spatial_dim=spatial_dim, sw_batch_size=1, cval=-1)
        result = inferer(input_volume, model)

        self.assertTupleEqual(result.shape, input_volume.shape)

        # test that the inferer can be run multiple times; the second output is
        # checked as well so a reused inferer cannot silently change the shape
        result = inferer(input_volume, model)
        self.assertTupleEqual(result.shape, input_volume.shape)
Exemplo n.º 2
0
    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
    ScaleIntensityd(keys=['img', 'seg']),
    ToTensord(keys=['img', 'seg'])
])
# Dataset that applies ``val_transforms`` to every entry of ``val_files``
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# that ``net.to(device)`` does not fail on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
net.to(device)

# define sliding window size and batch size for windows inference
roi_size = (96, 96, 96)
sw_batch_size = 4


def _sliding_window_processor(engine, batch):
    """Run sliding-window inference on one batch and return output and labels.

    Args:
        engine: not used in the body; accepted as the first positional argument.
        batch: mapping with an ``'img'`` and a ``'seg'`` entry, both of which
            must support ``.to(device)``.

    Returns:
        A ``(seg_probs, val_labels)`` tuple: the result of
        ``sliding_window_inference`` on the images, and the labels moved to
        ``device``.
    """
    net.eval()
    # Gradients are not needed for evaluation
    with torch.no_grad():
        images = batch['img'].to(device)
        labels = batch['seg'].to(device)
        predictions = sliding_window_inference(images, roi_size, sw_batch_size, net)
    return predictions, labels
def main():
    """Evaluate a 3D UNet with sliding-window inference on synthetic NIfTI data.

    Five synthetic image/segmentation volume pairs are written to a temporary
    directory, wrapped in a ``NiftiDataset`` and processed one volume per
    iteration by an ``Engine``. Network weights are loaded from
    ``./runs/net_checkpoint_50.pth``; a ``MeanDice`` metric, a ``StatsHandler``
    and a ``SegmentationSaver`` are attached to the evaluator.

    The temporary directory is removed when the function exits, including when
    data generation or evaluation raises.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # The context manager guarantees cleanup of the synthetic data even if a
    # later step fails (a bare mkdtemp()/rmtree() pair would leak it).
    with tempfile.TemporaryDirectory() as tempdir:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        # Sorting both listings keeps image i paired with segmentation i
        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

        # Use the GPU when present, otherwise the CPU; ``pin_memory`` below
        # already adapts to CUDA availability in the same way.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        net = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        net.to(device)

        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        def _sliding_window_processor(engine, batch):
            """Infer one batch; return ``(network output, labels)``."""
            net.eval()
            with torch.no_grad():
                val_images, val_labels = batch[0].to(device), batch[1].to(device)
                seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                return seg_probs, val_labels

        evaluator = Engine(_sliding_window_processor)

        # add evaluation metric to the evaluator engine
        MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
        val_stats_handler = StatsHandler(
            name='evaluator',
            output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
        )
        val_stats_handler.attach(evaluator)

        # for the array data format, assume the 3rd item of batch data is the meta_data
        # NOTE(review): output_dir is the literal relative path 'tempdir', not the
        # temporary directory above - presumably so saved predictions outlive the
        # cleanup; confirm this is intended.
        file_saver = SegmentationSaver(
            output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
            batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]))
        file_saver.attach(evaluator)

        # the model was trained by "unet_training_array" example
        ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
        ckpt_loader.attach(evaluator)

        # sliding window inference for one image at every iteration
        loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
        evaluator.run(loader)
Exemplo n.º 4
0
def main():
    """Evaluate a 3D UNet with sliding-window inference on synthetic NIfTI data.

    Five synthetic image/segmentation volume pairs are written to a temporary
    directory and described as ``{"img": ..., "seg": ...}`` dictionaries. They
    are loaded through dictionary-based transforms and processed one volume per
    iteration by an ``Engine``. Network weights are loaded from
    ``./runs/net_checkpoint_50.pth``; a ``MeanDice`` metric, a ``StatsHandler``
    and a ``SegmentationSaver`` are attached to the evaluator, and the final
    engine state is printed.

    The temporary directory is removed when the function exits, including when
    data generation or evaluation raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # The context manager guarantees cleanup of the synthetic data even if a
    # later step fails (a bare mkdtemp()/rmtree() pair would leak it).
    with tempfile.TemporaryDirectory() as tempdir:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        # Sorting both listings keeps image i paired with segmentation i
        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose([
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ])
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

        # Use the GPU when present, otherwise the CPU; ``pin_memory`` below
        # already adapts to CUDA availability in the same way.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        net = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        net.to(device)

        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        def _sliding_window_processor(engine, batch):
            """Infer one batch; return ``(network output, labels)``."""
            net.eval()
            with torch.no_grad():
                val_images, val_labels = batch["img"].to(device), batch["seg"].to(device)
                seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                return seg_probs, val_labels

        evaluator = Engine(_sliding_window_processor)

        # add evaluation metric to the evaluator engine
        MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice")

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
        val_stats_handler = StatsHandler(
            name="evaluator",
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        )
        val_stats_handler.attach(evaluator)

        # convert the necessary metadata from batch data
        # NOTE(review): output_dir is the literal relative path "tempdir", not the
        # temporary directory above - presumably so saved predictions outlive the
        # cleanup; confirm this is intended.
        SegmentationSaver(
            output_dir="tempdir",
            output_ext=".nii.gz",
            output_postfix="seg",
            name="evaluator",
            batch_transform=lambda batch: batch["img_meta_dict"],
            output_transform=lambda output: predict_segmentation(output[0]),
        ).attach(evaluator)
        # the model was trained by "unet_training_dict" example
        CheckpointLoader(load_path="./runs/net_checkpoint_50.pth", load_dict={"net": net}).attach(evaluator)

        # sliding window inference for one image at every iteration
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            num_workers=4,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )
        state = evaluator.run(val_loader)
        print(state)