Example #1
    def test_save_resized_content(self, output_ext):
        with tempfile.TemporaryDirectory() as tempdir:

            # set up engine
            def _train_func(engine, batch):
                return torch.randint(0, 255, (8, 1, 2, 2)).float()

            engine = Engine(_train_func)

            # set up testing handler
            saver = SegmentationSaver(output_dir=tempdir,
                                      output_postfix="seg",
                                      output_ext=output_ext,
                                      scale=255)
            saver.attach(engine)

            data = [{
                "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
                "spatial_shape": [(28, 28)] * 8,
                "affine": [np.diag(np.ones(4)) * 5] * 8,
                "original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
            }]
            engine.run(data, max_epochs=1)
            for i in range(8):
                filepath = os.path.join(
                    "testfile" + str(i),
                    "testfile" + str(i) + "_seg" + output_ext)
                self.assertTrue(os.path.exists(os.path.join(tempdir,
                                                            filepath)))
Example #2
    def test_save_resized_content(self, output_ext):
        default_dir = os.path.join(".", "tempdir")
        shutil.rmtree(default_dir, ignore_errors=True)

        # set up engine
        def _train_func(engine, batch):
            return torch.randint(0, 255, (8, 1, 2, 2)).float()

        engine = Engine(_train_func)

        # set up testing handler
        saver = SegmentationSaver(output_dir=default_dir, output_postfix="seg", output_ext=output_ext)
        saver.attach(engine)

        data = [
            {
                "filename_or_obj": ["testfile" + str(i) for i in range(8)],
                "spatial_shape": [(28, 28)] * 8,
                "affine": [np.diag(np.ones(4)) * 5] * 8,
                "original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
            }
        ]
        engine.run(data, max_epochs=1)
        for i in range(8):
            filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext)
            self.assertTrue(os.path.exists(os.path.join(default_dir, filepath)))
        shutil.rmtree(default_dir)
Example #3
    def test_saved_content(self, output_ext):
        with tempfile.TemporaryDirectory() as tempdir:

            # set up engine
            def _train_func(engine, batch):
                return torch.randint(0, 255, (8, 1, 2, 2)).float()

            engine = Engine(_train_func)

            # set up testing handler
            saver = SegmentationSaver(output_dir=tempdir,
                                      output_postfix="seg",
                                      output_ext=output_ext,
                                      scale=255)
            saver.attach(engine)

            data = [{
                "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
                "patch_index": list(range(8)),
            }]
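            # each patch_index value is appended to its saved filename, e.g. "testfile0_seg_0" + output_ext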
            engine.run(data, max_epochs=1)
            for i in range(8):
                filepath = os.path.join(
                    "testfile" + str(i),
                    "testfile" + str(i) + "_seg" + f"_{i}" + output_ext)
                self.assertTrue(os.path.exists(os.path.join(tempdir,
                                                            filepath)))
Example #4
def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
    ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False)
    loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())

    net = UNet(
        dimensions=3, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2
    ).to(device)
    roi_size = (16, 32, 48)
    sw_batch_size = batch_size

    def _sliding_window_processor(_engine, batch):
        net.eval()
        img, seg, meta_data = batch
        with torch.no_grad():
            seg_probs = sliding_window_inference(img.to(device), roi_size, sw_batch_size, net, device=device)
            return predict_segmentation(seg_probs)

    infer_engine = Engine(_sliding_window_processor)

    SegmentationSaver(
        output_dir=output_dir, output_ext=".nii.gz", output_postfix="seg", batch_transform=lambda x: x[2]
    ).attach(infer_engine)

    infer_engine.run(loader)

    basename = os.path.basename(img_name)[: -len(".nii.gz")]
    saved_name = os.path.join(output_dir, basename, f"{basename}_seg.nii.gz")
    return saved_name
Example #5
    def test_saved_content(self, output_ext):
        default_dir = os.path.join(".", "tempdir")
        shutil.rmtree(default_dir, ignore_errors=True)

        # set up engine
        def _train_func(engine, batch):
            return torch.zeros(8, 1, 2, 2)

        engine = Engine(_train_func)

        # set up testing handler
        saver = SegmentationSaver(output_dir=default_dir,
                                  output_postfix="seg",
                                  output_ext=output_ext)
        saver.attach(engine)

        data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}]
        engine.run(data, max_epochs=1)
        for i in range(8):
            filepath = os.path.join("testfile" + str(i),
                                    "testfile" + str(i) + "_seg" + output_ext)
            self.assertTrue(os.path.exists(os.path.join(default_dir,
                                                        filepath)))
        shutil.rmtree(default_dir)
Example #6
    def test_saved_content(self):
        default_dir = os.path.join('.', 'tempdir')
        shutil.rmtree(default_dir, ignore_errors=True)

        # set up engine
        def _train_func(engine, batch):
            return torch.zeros(8, 1, 2, 2)

        engine = Engine(_train_func)

        # set up testing handler
        saver = SegmentationSaver(output_dir=default_dir,
                                  output_postfix='seg',
                                  output_ext='.nii.gz')
        saver.attach(engine)

        data = [{'filename_or_obj': ['testfile' + str(i) for i in range(8)]}]
        engine.run(data, max_epochs=1)
        for i in range(8):
            filepath = os.path.join('testfile' + str(i),
                                    'testfile' + str(i) + '_seg.nii.gz')
            self.assertTrue(os.path.exists(os.path.join(default_dir,
                                                        filepath)))
        shutil.rmtree(default_dir)
Example #7
# StatsHandler prints loss at every iteration and metrics at every epoch;
# we don't need to print loss for the evaluator, so just print metrics (users can also customize the print functions)
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None  # no need to print loss value, so disable per-iteration output
)
val_stats_handler.attach(evaluator)

# convert the necessary metadata from batch data
SegmentationSaver(
    output_dir='tempdir',
    output_ext='.nii.gz',
    output_postfix='seg',
    name='evaluator',
    batch_transform=lambda batch: {
        'filename_or_obj': batch['img.filename_or_obj'],
        'affine': batch['img.affine']
    },
    output_transform=lambda output: predict_segmentation(output[0]),
).attach(evaluator)
# the model was trained by "unet_training_dict" example
CheckpointLoader(load_path='./runs/net_checkpoint_50.pth',
                 load_dict={'net': net}).attach(evaluator)

# sliding window inference for one image at every iteration
val_loader = DataLoader(val_ds,
                        batch_size=1,
                        num_workers=4,
                        collate_fn=list_data_collate,
                        pin_memory=torch.cuda.is_available())
evaluator.run(val_loader)
Example #8
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process (each GPU runs in a separate process)
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False)
    # sliding window inference needs to input one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler)

    # create the UNet model
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[device])

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        CheckpointLoader(
            load_path="./runs/checkpoint_epoch=4.pt",
            load_dict={"net": net},
            # map the checkpoint tensors to the expected GPU device
            map_location={"cuda:0": f"cuda:{args.local_rank}"},
        ),
    ]
    if dist.get_rank() == 0:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        val_handlers.extend(
            [
                StatsHandler(output_transform=lambda x: None),
                SegmentationSaver(
                    output_dir="./runs/",
                    batch_transform=lambda batch: batch["image_meta_dict"],
                    output_transform=lambda output: output["pred"],
                ),
            ]
        )

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=device,
            )
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)},
        val_handlers=val_handlers,
        # AMP evaluation requires FP16 support in the GPU and PyTorch >= 1.6
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    evaluator.run()
    dist.destroy_process_group()
Example #9
def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    device = torch.device('cuda:0')
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4


    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            return seg_probs, val_labels


    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't need to print loss for the evaluator, so just print metrics (users can also customize the print functions)
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
        batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]))
    file_saver.attach(evaluator)
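    # (illustrative, assuming NiftiDataset with image_only=False) the meta_data dict at x[2] typically
    # carries keys such as "filename_or_obj", "affine", "original_affine" and "spatial_shape",
    # which SegmentationSaver uses to name each output file and restore its spatial metadata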

    # the model was trained by "unet_training_array" example
    ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    shutil.rmtree(tempdir)
Example #10
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_workers=4):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # create the UNet model
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_postprocessing = Compose(
        [
            ToTensord(keys=["pred", "label"]),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`
            SaveImaged(
                keys="pred", meta_keys=PostFix.meta("image"), output_dir=root_dir, output_postfix="seg_transform"
            ),
        ]
    )
    val_handlers = [
        StatsHandler(iteration_log=False),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            output_postfix="seg_handler",
            batch_transform=from_engine(PostFix.meta("image")),
            output_transform=from_engine("pred"),
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.run()

    return evaluator.state.best_metric
Example #11
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    device = torch.device("cuda:0")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch["img"].to(device), batch["seg"].to(
                device)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 sw_batch_size, net)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't need to print loss for the evaluator, so just print metrics (users can also customize the print functions)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per-iteration output
    )
    val_stats_handler.attach(evaluator)

    # convert the necessary metadata from batch data
    SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda batch: batch["img_meta_dict"],
        output_transform=lambda output: predict_segmentation(output[0]),
    ).attach(evaluator)
    # the model was trained by "unet_training_dict" example
    CheckpointLoader(load_path="./runs/net_checkpoint_50.pth",
                     load_dict={"net": net}).attach(evaluator)

    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())
    state = evaluator.run(val_loader)
    print(state)
    shutil.rmtree(tempdir)
Example #12
MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

# StatsHandler prints loss at every iteration and metrics at every epoch;
# we don't need to print loss for the evaluator, so just print metrics (users can also customize the print functions)
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None  # no need to print loss value, so disable per-iteration output
)
val_stats_handler.attach(evaluator)

# for the array data format, assume the 3rd item of batch data is the meta_data
file_saver = SegmentationSaver(
    output_dir='tempdir',
    output_ext='.nii.gz',
    output_postfix='seg',
    name='evaluator',
    batch_transform=lambda x: x[2],
    output_transform=lambda output: predict_segmentation(output[0]))
file_saver.attach(evaluator)

# the model was trained by "unet_training_array" example
ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth',
                               load_dict={'net': net})
ckpt_loader.attach(evaluator)

# sliding window inference for one image at every iteration
loader = DataLoader(ds,
                    batch_size=1,
                    num_workers=1,
                    pin_memory=torch.cuda.is_available())
Example #13
def run_inference(input_data, config_info):
    """
    Pipeline to run inference with MONAI dynUNet model. The pipeline reads the input filenames, applies the required
    preprocessing and creates the pytorch dataloader; it then performs evaluation on each input file using a trained
    dynUNet model (random flipping augmentation is applied at inference).
    It uses the dynUNet model implemented in the MONAI framework
    (https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/nets/dynunet.py)
    which is inspired by the nnU-Net framework (https://arxiv.org/abs/1809.10486)
    Inference is performed in 2D slice-by-slice, all slices are then recombined together into the 3D volume.

    Args:
        input_data: str or list of strings, filenames of images to be processed
        config_info: dict, contains the configuration parameters to reload the trained model

    """
    """
    Read input and configuration parameters
    """

    val_files = create_data_list_of_dictionaries(input_data)

    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print("*** MONAI config: ")
    print_config()

    # print to log the parameter setups
    print("*** Network inference config: ")
    print(yaml.dump(config_info))

    # inference params
    nr_out_channels = config_info['inference']['nr_out_channels']
    spacing = config_info["inference"]["spacing"]
    prob_thr = config_info['inference']['probability_threshold']
    model_to_load = config_info['inference']['model_to_load']
    if not os.path.exists(model_to_load):
        raise FileNotFoundError('Trained model not found')
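    # appending a singleton third dimension makes the sliding window run 2D, slice by slice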
    patch_size = config_info["inference"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}".format(
            torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))
    """
    Data Preparation
    """
    print("***  Preparing data ... ")
    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - resample to the training resolution in-plane (not along z)
    # - apply whitening
    # - convert to tensor
    val_transforms = Compose([
        LoadNiftid(keys=["image"]),
        AddChanneld(keys=["image"]),
        InPlaneSpacingd(
            keys=["image"],
            pixdim=spacing,
            mode="bilinear",
        ),
        NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
        ToTensord(keys=["image"]),
    ])
    # create a validation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=config_info['device']['num_workers'])

    def prepare_batch(batchdata):
        assert isinstance(batchdata, dict), "prepare_batch expects dictionary input data."
        return (batchdata["image"], batchdata["label"]) if "label" in batchdata else (batchdata["image"], None)

    """
    Network preparation
    """
    print("***  Preparing network ... ")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
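    # illustrative worked example (assuming isotropic in-plane spacing [1.0, 1.0] and patch size [160, 160]):
    # each pass halves the in-plane size until it drops below 8, yielding
    #   strides = [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]]
    #   kernels = [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]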

    net = DynUNet(spatial_dims=2,
                  in_channels=1,
                  out_channels=nr_out_channels,
                  kernel_size=kernels,
                  strides=strides,
                  upsample_kernel_size=strides[1:],
                  norm_name="instance",
                  deep_supervision=True,
                  deep_supr_num=2,
                  res_block=False).to(current_device)
    """
    Set ignite evaluator to perform inference
    """
    print("***  Preparing evaluator ... ")
    if nr_out_channels == 1:
        do_sigmoid = True
        do_softmax = False
    elif nr_out_channels > 1:
        do_sigmoid = False
        do_softmax = True
    else:
        raise ValueError("incompatible number of output channels")
    print("Using sigmoid={} and softmax={} as final activation".format(
        do_sigmoid, do_softmax))
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=do_sigmoid, softmax=do_softmax),
        AsDiscreted(keys="pred",
                    argmax=True,
                    threshold_values=True,
                    logit_thresh=prob_thr),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=1)
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_to_load,
                         load_dict={"net": net},
                         map_location=torch.device('cpu')),
        SegmentationSaver(
            output_dir=config_info['output']['out_dir'],
            output_ext='.nii.gz',
            output_postfix=config_info['output']['out_postfix'],
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs = inputs.to(engine.state.device)
            if targets is not None:
                targets = targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2, ))
            flip_inputs_2 = torch.flip(inputs, dims=(3, ))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # flip-based test-time augmentation: average predictions over the flipped inputs
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1,
                                                      self.network),
                                         dims=(2, ))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2,
                                                      self.network),
                                         dims=(3, ))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3,
                                                      self.network),
                                         dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        prepare_batch=prepare_batch,
        inferer=SlidingWindowInferer2D(roi_size=patch_size,
                                       sw_batch_size=4,
                                       overlap=0.0),
        post_transform=val_post_transforms,
        val_handlers=val_handlers,
        amp=False,
    )
    """
    Run inference
    """
    print("***  Running evaluator ... ")
    evaluator.run()
    print("Done!")

    return
Example #14
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 5 random image/mask pairs in the given temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # model file path
    model_file = glob("./runs/net_key_metric*")[0]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create the UNet model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
        # AMP evaluation requires FP16 support in the GPU and PyTorch >= 1.6
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    evaluator.run()
Example #15
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create the UNet model
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True,
                                      output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.run()

    return evaluator.state.best_metric
Example #16
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = ImageDataset(images,
                      segs,
                      transform=imtrans,
                      seg_transform=segtrans,
                      image_only=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 sw_batch_size, net)
            seg_probs = post_trans(seg_probs)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't need to print loss for the evaluator, so just print metrics (users can also customize the print functions)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per-iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: output[0],
    )
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_loader = CheckpointLoader(
        load_path="./runs_array/net_checkpoint_100.pt", load_dict={"net": net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds,
                        batch_size=1,
                        num_workers=1,
                        pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    print(state)
Example #17
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 5 random image/mask pairs
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create the UNet model
    device = torch.device("cuda:0")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", output_postfix="act", sigmoid=True),
        AsDiscreted(keys="pred_act",
                    output_postfix="dis",
                    threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred_act_dis",
                                       applied_values=[1],
                                       output_postfix=None),
    ])
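    # note: this older MONAI post-transform API writes each result under a new key via output_postfix,
    # chaining "pred" -> "pred_act" -> "pred_act_dis"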
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path="./runs/net_key_metric=0.9101.pth",
                         load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda x: {
                "filename_or_obj": x["image.filename_or_obj"],
                "affine": x["image.affine"],
                "original_affine": x["image.original_affine"],
                "spatial_shape": x["image.spatial_shape"],
            },
            output_transform=lambda x: x["pred_act_dis"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True,
                                      output_transform=lambda x: (x["pred_act_dis"], x["label"]))
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=lambda x: (x["pred_act_dis"], x["label"]))
        },
        val_handlers=val_handlers,
    )
    evaluator.run()
    shutil.rmtree(tempdir)