    def test_saved_content(self, output_ext):
        with tempfile.TemporaryDirectory() as tempdir:

            # set up engine
            def _train_func(engine, batch):
                return torch.randint(0, 255, (8, 1, 2, 2)).float()

            engine = Engine(_train_func)

            # set up testing handler
            saver = SegmentationSaver(output_dir=tempdir,
                                      output_postfix="seg",
                                      output_ext=output_ext,
                                      scale=255)
            saver.attach(engine)

            data = [{
                "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
                "patch_index": list(range(8)),  # appended to each saved filename
            }]
            engine.run(data, max_epochs=1)
            for i in range(8):
                filepath = os.path.join(
                    "testfile" + str(i),
                    "testfile" + str(i) + "_seg" + f"_{i}" + output_ext)
                self.assertTrue(os.path.exists(os.path.join(tempdir,
                                                            filepath)))

    def test_save_resized_content(self, output_ext):
        with tempfile.TemporaryDirectory() as tempdir:

            # set up engine
            def _train_func(engine, batch):
                return torch.randint(0, 255, (8, 1, 2, 2)).float()

            engine = Engine(_train_func)

            # set up testing handler
            saver = SegmentationSaver(output_dir=tempdir,
                                      output_postfix="seg",
                                      output_ext=output_ext,
                                      scale=255)
            saver.attach(engine)

            data = [{
                "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
                "spatial_shape": [(28, 28)] * 8,
                "affine": [np.diag(np.ones(4)) * 5] * 8,
                "original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
            }]
            engine.run(data, max_epochs=1)
            for i in range(8):
                filepath = os.path.join(
                    "testfile" + str(i),
                    "testfile" + str(i) + "_seg" + output_ext)
                self.assertTrue(os.path.exists(os.path.join(tempdir,
                                                            filepath)))

    def test_save_resized_content(self, output_ext):
        default_dir = os.path.join(".", "tempdir")
        shutil.rmtree(default_dir, ignore_errors=True)

        # set up engine
        def _train_func(engine, batch):
            return torch.randint(0, 255, (8, 1, 2, 2)).float()

        engine = Engine(_train_func)

        # set up testing handler
        saver = SegmentationSaver(output_dir=default_dir, output_postfix="seg", output_ext=output_ext)
        saver.attach(engine)

        data = [
            {
                "filename_or_obj": ["testfile" + str(i) for i in range(8)],
                "spatial_shape": [(28, 28)] * 8,
                "affine": [np.diag(np.ones(4)) * 5] * 8,
                "original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
            }
        ]
        engine.run(data, max_epochs=1)
        for i in range(8):
            filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext)
            self.assertTrue(os.path.exists(os.path.join(default_dir, filepath)))
        shutil.rmtree(default_dir)
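
All of the assertions above check the same naming convention: SegmentationSaver writes each result into a per-subject folder named after the input file stem, with the postfix (and, when a patch_index is provided, the patch index) appended before the extension. A minimal sketch of that convention as a plain helper, for reference only; expected_save_path is a hypothetical name, not part of the MONAI API:

import os

def expected_save_path(output_dir, filename_or_obj, postfix, ext, patch_index=None):
    # strip a known image extension from the input name, e.g. "testfile0.nii.gz" -> "testfile0"
    stem = os.path.basename(filename_or_obj)
    for suffix in (".nii.gz", ".nii", ".png"):
        if stem.endswith(suffix):
            stem = stem[: -len(suffix)]
            break
    # <output_dir>/<stem>/<stem>_<postfix>[_<patch_index>]<ext>
    name = stem + "_" + postfix
    if patch_index is not None:
        name += "_" + str(patch_index)
    return os.path.join(output_dir, stem, name + ext)

# e.g. expected_save_path(tempdir, "testfile0.nii.gz", "seg", ".nii.gz", patch_index=0)
# reproduces the path asserted in the first test above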
Example #4
    def test_saved_content(self, output_ext):
        default_dir = os.path.join(".", "tempdir")
        shutil.rmtree(default_dir, ignore_errors=True)

        # set up engine
        def _train_func(engine, batch):
            return torch.zeros(8, 1, 2, 2)

        engine = Engine(_train_func)

        # set up testing handler
        saver = SegmentationSaver(output_dir=default_dir,
                                  output_postfix="seg",
                                  output_ext=output_ext)
        saver.attach(engine)

        data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}]
        engine.run(data, max_epochs=1)
        for i in range(8):
            filepath = os.path.join("testfile" + str(i),
                                    "testfile" + str(i) + "_seg" + output_ext)
            self.assertTrue(os.path.exists(os.path.join(default_dir,
                                                        filepath)))
        shutil.rmtree(default_dir)
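
Each of these test methods takes output_ext as an argument; in the MONAI test suite such arguments are typically supplied by the parameterized package. A minimal sketch of that wiring, assuming parameterized is available; the class and test-case names below are illustrative only:

import unittest
from parameterized import parameterized

TEST_CASE_NIFTI = [".nii.gz"]  # hypothetical case definitions, for illustration
TEST_CASE_PNG = [".png"]

class TestHandlerSegmentationSaver(unittest.TestCase):
    @parameterized.expand([TEST_CASE_NIFTI, TEST_CASE_PNG])
    def test_saved_content(self, output_ext):
        ...  # body as in the examples above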
Example #5
    def test_saved_content(self):
        default_dir = os.path.join('.', 'tempdir')
        shutil.rmtree(default_dir, ignore_errors=True)

        # set up engine
        def _train_func(engine, batch):
            return torch.zeros(8, 1, 2, 2)

        engine = Engine(_train_func)

        # set up testing handler
        saver = SegmentationSaver(output_dir=default_dir,
                                  output_postfix='seg',
                                  output_ext='.nii.gz')
        saver.attach(engine)

        data = [{'filename_or_obj': ['testfile' + str(i) for i in range(8)]}]
        engine.run(data, max_epochs=1)
        for i in range(8):
            filepath = os.path.join('testfile' + str(i),
                                    'testfile' + str(i) + '_seg.nii.gz')
            self.assertTrue(os.path.exists(os.path.join(default_dir,
                                                        filepath)))
        shutil.rmtree(default_dir)
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = ImageDataset(images,
                      segs,
                      transform=imtrans,
                      seg_transform=segtrans,
                      image_only=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # define the window size and batch size for sliding window inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 sw_batch_size, net)
            seg_probs = post_trans(seg_probs)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and prints metrics at every epoch;
    # we don't need to print loss for the evaluator, so just print metrics (users can also customize the print functions)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per-iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: output[0],
    )
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_saver = CheckpointLoader(
        load_path="./runs_array/net_checkpoint_100.pt", load_dict={"net": net})
    ckpt_saver.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds,
                        batch_size=1,
                        num_workers=1,
                        pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    print(state)
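
In the MONAI example scripts, a main(tempdir) function like this is usually driven from a temporary-directory context so the synthetic data is cleaned up automatically; a minimal sketch of that entry point, under that assumption:

import tempfile

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)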
def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    device = torch.device('cuda:0')
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define the window size and batch size for sliding window inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4


    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            return seg_probs, val_labels


    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

    # StatsHandler prints loss at every iteration and prints metrics at every epoch;
    # we don't need to print loss for the evaluator, so just print metrics (users can also customize the print functions)
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
        batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]))
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_saver = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
    ckpt_saver.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    shutil.rmtree(tempdir)
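
This older variant binarizes the network output inside the saver's output_transform with predict_segmentation. For a single-channel, non-mutually-exclusive output that roughly amounts to thresholding the logits; a behavioural sketch in plain torch, shown only to clarify what gets written to disk (predict_segmentation_sketch is a hypothetical stand-in, not the MONAI function):

def predict_segmentation_sketch(logits, threshold=0.0):
    # approximate behaviour of monai.networks.utils.predict_segmentation with
    # mutually_exclusive=False: binarize the logits at the threshold
    return (logits >= threshold).int()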
Example #8
# StatsHandler prints loss at every iteration and prints metrics at every epoch;
# we don't need to print loss for the evaluator, so just print metrics (users can also customize the print functions)
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
)
val_stats_handler.attach(evaluator)

# for the array data format, assume the 3rd item of batch data is the meta_data
file_saver = SegmentationSaver(
    output_dir='tempdir',
    output_ext='.nii.gz',
    output_postfix='seg',
    name='evaluator',
    batch_transform=lambda x: x[2],
    output_transform=lambda output: predict_segmentation(output[0]))
file_saver.attach(evaluator)

# the model was trained by "unet_training_array" example
ckpt_saver = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth',
                              load_dict={'net': net})
ckpt_saver.attach(evaluator)

# sliding window inference for one image at every iteration
loader = DataLoader(ds,
                    batch_size=1,
                    num_workers=1,
                    pin_memory=torch.cuda.is_available())
state = evaluator.run(loader)
shutil.rmtree(tempdir)
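
Example #8 binarizes predictions with predict_segmentation inside output_transform; the first main() above reaches the same result with a Compose of Activations and AsDiscrete. A minimal sketch of wiring that Compose into the saver instead, assuming the transforms accept the batched output exactly as in the earlier example:

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

file_saver = SegmentationSaver(
    output_dir='tempdir',
    output_ext='.nii.gz',
    output_postfix='seg',
    name='evaluator',
    batch_transform=lambda x: x[2],  # meta data is the 3rd item of the batch
    output_transform=lambda output: post_trans(output[0]),  # sigmoid + threshold before saving
)
file_saver.attach(evaluator)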