Example #1
0
def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")):
    """Train a small 2D UNet for one epoch on synthetic data and report the loss.

    Args:
        batch_size: number of synthetic samples per DataLoader batch.
        train_steps: length of the synthetic dataset, i.e. samples per epoch.
        device: device handed to ``create_supervised_trainer``.

    Returns:
        ``trainer.state.output`` after the single epoch (presumably the loss of
        the last iteration -- depends on the trainer's output transform).
    """

    class _SyntheticImages(Dataset):
        # Each index yields a freshly generated random image/label pair;
        # the index value itself is ignored.
        def __len__(self):
            return train_steps

        def __getitem__(self, _index):
            image, label = create_test_image_2d(
                128, 128, noise_max=1, num_objs=4, num_seg_classes=1
            )
            # Prepend a channel axis; the label is cast to float32 for the loss.
            return image[None], label[None].astype(np.float32)

    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    )
    loss_fn = DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4)
    loader = DataLoader(_SyntheticImages(), batch_size=batch_size)

    trainer = create_supervised_trainer(model, optimizer, loss_fn, device, False)
    trainer.run(loader, 1)  # exactly one epoch

    final_loss = trainer.state.output
    print('Loss:', final_loss)
    if final_loss >= 1:
        print('Loss value is wrong, expect to be < 1.')
    return final_loss
def run_test(batch_size=2, device=torch.device("cpu:0")):
    """Run sliding-window UNet inference on a synthetic 3D volume and check the saved shape.

    A random image/segmentation pair is written to NIfTI files and loaded
    through ``NiftiDataset``. The image is segmented with
    ``sliding_window_inference`` inside an ignite ``Engine``, and the result is
    written by ``SegmentationSaver`` into a temporary directory.

    Args:
        batch_size: number of sliding windows per forward pass (passed as
            ``sw_batch_size``); the DataLoader itself always uses batch size 1.
        device: device argument forwarded to ``sliding_window_inference``.

    Returns:
        True if the saved segmentation has the same shape as the input image.
    """
    im, seg = create_test_image_3d(25,
                                   28,
                                   63,
                                   rad_max=10,
                                   noise_max=1,
                                   num_objs=4,
                                   num_seg_classes=1)
    input_shape = im.shape

    # Track every NIfTI file we create so it is removed even when loading,
    # inference or saving raises; the original only cleaned up on success.
    created_files = []
    try:
        img_name = make_nifti_image(im)
        created_files.append(img_name)
        seg_name = make_nifti_image(seg)
        created_files.append(seg_name)

        ds = NiftiDataset([img_name], [seg_name],
                          transform=AddChannel(),
                          seg_transform=AddChannel(),
                          image_only=False)
        loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())

        net = UNet(
            dimensions=3,
            in_channels=1,
            num_classes=1,
            channels=(4, 8, 16, 32),
            strides=(2, 2, 2),
            num_res_units=2,
        )
        roi_size = (16, 32, 48)
        sw_batch_size = batch_size

        def _sliding_window_processor(_engine, batch):
            # Engine step: batches are (image, segmentation, metadata); only the
            # image is used for inference.
            net.eval()
            img, _seg, _meta_data = batch
            with torch.no_grad():
                seg_probs = sliding_window_inference(img, roi_size, sw_batch_size,
                                                     lambda x: net(x)[0], device)
                return predict_segmentation(seg_probs)

        infer_engine = Engine(_sliding_window_processor)

        with tempfile.TemporaryDirectory() as temp_dir:
            SegmentationSaver(output_path=temp_dir,
                              output_ext='.nii.gz',
                              output_postfix='seg').attach(infer_engine)

            infer_engine.run(loader)

            # The saver writes <temp_dir>/<basename>/<basename>_seg.nii.gz.
            basename = os.path.basename(img_name)[:-len('.nii.gz')]
            saved_name = os.path.join(temp_dir, basename,
                                      '{}_seg.nii.gz'.format(basename))
            testing_shape = nib.load(saved_name).get_fdata().shape
    finally:
        for name in created_files:
            if os.path.exists(name):
                os.remove(name)

    return testing_shape == input_shape
Example #3
0
def get_net(cfg: DictConfig) -> torch.nn.Module:
    """Build the network selected by ``cfg["net_name"]``.

    Supported names:
        ``"UNet"``   -- ``UNet`` with ``spatial_dims=3``.
        ``"MyUNet"`` -- ``UNet3D`` constructed from the whole config.
        ``"UNETR"``  -- ``UNETR``.

    Args:
        cfg: config providing ``net_name`` plus the hyper-parameters read by the
            chosen architecture (``in_channels``, ``out_channels``, ...).

    Returns:
        The instantiated network.

    Raises:
        NotImplementedError: if ``net_name`` is not one of the supported names.
    """
    net_name = cfg["net_name"]
    if net_name == "UNet":
        return UNet(
            spatial_dims=3,
            in_channels=cfg.in_channels,
            out_channels=cfg.out_channels,
            channels=cfg.channels,
            strides=cfg.strides,
            dropout=cfg.dropout,
        )
    if net_name == "MyUNet":
        return UNet3D(cfg)
    if net_name == "UNETR":
        return UNETR(
            in_channels=cfg.in_channels,
            out_channels=cfg.out_channels,
            img_size=cfg.img_size,
            dropout_rate=cfg.dropout,
            res_block=cfg.res_block,
        )

    # Name the rejected value so a config typo is easy to diagnose.
    raise NotImplementedError("Unsupported net_name: {!r}".format(net_name))
Example #4
0
# NOTE(review): `tempdir` and `images` are not defined in the visible part of
# this example; presumably they are created earlier (e.g. synthetic NIfTI files
# written into `tempdir`) -- confirm.
# Segmentation files are sorted, presumably so they pair with `images` by index.
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

# Define transforms for image and segmentation
# The image is rescaled and gets a leading channel axis; the segmentation only
# gets the channel axis (its label values are not rescaled).
imtrans = transforms.Compose([Rescale(), AddChannel()])
segtrans = transforms.Compose([AddChannel()])
# image_only=False: items are unpacked below as (img, seg, meta_data).
ds = NiftiDataset(images,
                  segs,
                  transform=imtrans,
                  seg_transform=segtrans,
                  image_only=False)

# NOTE(review): hard-coded to the first CUDA GPU; `net.to(device)` below will
# presumably fail on a machine without CUDA.
device = torch.device("cuda:0")
# 3D UNet, single-channel in/out, 5 channel levels with 4 stride-2 steps,
# num_res_units=2.
net = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
net.to(device)

# Define the sliding-window size and the number of windows evaluated per
# forward pass for window inference.
roi_size = (96, 96, 96)
sw_batch_size = 4


def _sliding_window_processor(engine, batch):
    net.eval()
    img, seg, meta_data = batch
    with torch.no_grad():
        seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net,
Example #5
0
 def test_shape(self, input_param, input_data, expected_shape):
     """Check that the second element of UNet's output has ``expected_shape``.

     The network is built from ``input_param`` and run in eval mode without
     gradient tracking.
     """
     net = UNet(**input_param)
     net.eval()
     with torch.no_grad():
         # Call the module instance rather than ``.forward`` so that any
         # registered forward hooks are run, as PyTorch recommends.
         result = net(input_data)[1]
     self.assertEqual(result.shape, expected_shape)
Example #6
0
 def test_shape(self, input_param, input_data, expected_shape):
     """Check that the second element of UNet's output has ``expected_shape``."""
     net = UNet(**input_param)
     # Only the output shape is checked, so run deterministically (eval mode)
     # and without building an autograd graph.
     net.eval()
     with torch.no_grad():
         # Call the module instance rather than ``.forward`` so registered
         # forward hooks are run.
         result = net(input_data)[1]
     self.assertEqual(result.shape, expected_shape)