示例#1
0
def check_roi_random_crop_error(shape_like_in=None,
                                in_shape=None,
                                crop_shape=None,
                                roi_start=None,
                                roi_shape=None,
                                roi_end=None):
    """Assert that ``fn.roi_random_crop`` rejects the given argument combination.

    A CPU pipeline with a single ``fn.roi_random_crop`` operator is defined
    from the supplied arguments; building it and running it must raise
    ``RuntimeError`` at some point, otherwise ``assert_raises`` fails.

    Args:
        shape_like_in: Optional positional input of the operator; when
            ``None`` the operator is created without positional inputs.
        in_shape: Forwarded as the ``in_shape`` argument.
        crop_shape: Forwarded as the ``crop_shape`` argument.
        roi_start: Forwarded as the ``roi_start`` argument.
        roi_shape: Forwarded as the ``roi_shape`` argument.
        roi_end: Forwarded as the ``roi_end`` argument.
    """
    batch_size = 3
    niter = 3
    pipe = dali.pipeline.Pipeline(batch_size=batch_size,
                                  num_threads=4,
                                  device_id=0,
                                  seed=1234)
    with pipe:
        inputs = [] if shape_like_in is None else [shape_like_in]
        out = fn.roi_random_crop(*inputs,
                                 in_shape=in_shape,
                                 crop_shape=crop_shape,
                                 roi_start=roi_start,
                                 roi_shape=roi_shape,
                                 roi_end=roi_end,
                                 device='cpu')
    pipe.set_outputs(out)
    # Both build() and run() sit inside the context: the expected error may
    # come from either of them.
    with assert_raises(RuntimeError):
        pipe.build()
        for _ in range(niter):
            pipe.run()
示例#2
0
    def biased_crop_fn(self, img, label):
        """Crop ``img`` and ``label`` at an anchor derived from a random object box.

        ``fn.segmentation.random_object_bbox`` is asked for a start/end box on
        ``label`` (``foreground_prob=self.oversampling``, ``background=0``,
        ``seed=self.internal_seed``). ``fn.roi_random_crop`` turns that box into
        a crop anchor, whose first coordinate is then sliced off. Both inputs
        are finally sliced at the remaining anchor with ``self.crop_shape`` and
        ``out_of_bounds_policy="pad"``.

        Args:
            img: Image input, sliced together with ``label``.
            label: Label input; also the source of the object box.

        Returns:
            Tuple ``(img, label)`` of the sliced outputs after ``.gpu()``.
        """
        bbox_start, bbox_end = fn.segmentation.random_object_bbox(
            label,
            format="start_end",
            foreground_prob=self.oversampling,
            background=0,
            seed=self.internal_seed,
            device="cpu",
            cache_objects=True,
        )

        # crop_shape gets a leading 1 ahead of patch_size; the matching first
        # anchor coordinate is removed by the slice right below.
        window = [1, *self.patch_size]
        full_anchor = fn.roi_random_crop(label, roi_start=bbox_start, roi_end=bbox_end, crop_shape=window)
        # drop channels from anchor
        spatial_anchor = fn.slice(full_anchor, 1, 3, axes=[0])

        cropped = fn.slice(
            [img, label],
            spatial_anchor,
            self.crop_shape,
            axis_names="DHW",
            out_of_bounds_policy="pad",
            device="cpu",
        )
        cropped_img, cropped_label = cropped
        return cropped_img.gpu(), cropped_label.gpu()
def check_roi_random_crop(ndim=2, max_batch_size=16,
                          roi_min_start=0, roi_max_start=100,
                          roi_min_extent=20, roi_max_extent=50,
                          crop_min_extent=20, crop_max_extent=50,
                          in_shape_min=400, in_shape_max=500,
                          niter=3):
    """Run ``fn.roi_random_crop`` in six argument combinations and check the anchors.

    The operator is instantiated with the ROI given as start+shape and as
    start+end, each time (a) without any input shape, (b) with a shape-like
    positional input and (c) with an explicit ``in_shape`` argument. Crop shape
    and ROI are, by a coin flip each, either Python constants (midpoints of
    the given ranges) or per-sample ``fn.random.uniform`` outputs.

    For every sample and every dimension the helper asserts that:
      * when an input shape is known, the crop lies in ``[0, in_shape]``;
      * a crop at least as large as the ROI contains the ROI, and a smaller
        crop lies entirely inside the ROI.

    Args:
        ndim: Number of dimensions of the generated shapes, ROIs and crops.
        max_batch_size: Batch size of the pipeline and of the generated data.
        roi_min_start, roi_max_start: Range for the ROI start coordinates.
        roi_min_extent, roi_max_extent: Range for the ROI extents.
        crop_min_extent, crop_max_extent: Range for the crop extents.
        in_shape_min, in_shape_max: Bounds passed to ``random_shape`` for the
            shape-like input; ``in_shape_min`` must be smaller.
        niter: Number of pipeline runs to verify.

    Raises:
        AssertionError: If the shape bounds are inconsistent or any produced
            crop anchor violates the constraints listed above.
    """
    assert in_shape_min < in_shape_max, \
        f"in_shape_min ({in_shape_min}) must be smaller than in_shape_max ({in_shape_max})"

    # Defined once, up front, instead of being re-created for every sample.
    def check_crop_start(crop_start, roi_start, roi_shape, crop_shape, in_shape=None):
        """Assert that one crop anchor is consistent with its ROI (and input shape)."""
        for d, start in enumerate(crop_start):
            end = start + crop_shape[d]
            roi_lo = roi_start[d]
            roi_hi = roi_lo + roi_shape[d]
            where = (f"dim {d}: crop=[{start}, {end}), roi=[{roi_lo}, {roi_hi}), "
                     f"in_shape={in_shape}")

            if in_shape is not None:
                assert start >= 0, f"crop starts before the input; {where}"
                assert end <= in_shape[d], f"crop ends past the input; {where}"

            if crop_shape[d] >= roi_shape[d]:
                # The crop is large enough: it has to cover the whole ROI.
                assert start <= roi_lo, f"crop does not cover the ROI start; {where}"
                assert end >= roi_hi, f"crop does not cover the ROI end; {where}"
            else:
                # The crop is smaller than the ROI: it has to stay inside it.
                assert start >= roi_lo, f"crop starts outside the ROI; {where}"
                assert end <= roi_hi, f"crop ends outside the ROI; {where}"

    def shape_gen_fn():
        return random_shape(in_shape_min, in_shape_max, ndim)

    def data_gen_f():
        return batch_gen(max_batch_size, shape_gen_fn)

    pipe = dali.pipeline.Pipeline(batch_size=max_batch_size, num_threads=4, device_id=0, seed=1234)
    with pipe:
        shape_like_in = dali.fn.external_source(data_gen_f, device='cpu')
        in_shape = dali.fn.shapes(shape_like_in, dtype=types.INT32)

        # Coin flip: constant crop shape vs. a per-sample random one.
        if random.choice([True, False]):
            crop_shape = [(crop_min_extent + crop_max_extent) // 2] * ndim
        else:
            crop_shape = fn.random.uniform(range=(crop_min_extent, crop_max_extent + 1),
                                           shape=(ndim,), dtype=types.INT32, device='cpu')

        # Coin flip: constant ROI vs. a per-sample random one.
        if random.choice([True, False]):
            roi_shape = [(roi_min_extent + roi_max_extent) // 2] * ndim
            roi_start = [(roi_min_start + roi_max_start) // 2] * ndim
            roi_end = [start + extent for start, extent in zip(roi_start, roi_shape)]
        else:
            roi_shape = fn.random.uniform(range=(roi_min_extent, roi_max_extent + 1),
                                          shape=(ndim,), dtype=types.INT32, device='cpu')
            roi_start = fn.random.uniform(range=(roi_min_start, roi_max_start + 1),
                                          shape=(ndim,), dtype=types.INT32, device='cpu')
            roi_end = roi_start + roi_shape

        # Variants without any input shape information.
        unbounded_outs = [
            fn.roi_random_crop(crop_shape=crop_shape,
                               roi_start=roi_start, roi_shape=roi_shape,
                               device='cpu'),
            fn.roi_random_crop(crop_shape=crop_shape,
                               roi_start=roi_start, roi_end=roi_end,
                               device='cpu'),
        ]
        # Variants that receive the input shape, positionally or via in_shape.
        bounded_outs = [
            fn.roi_random_crop(shape_like_in, crop_shape=crop_shape,
                               roi_start=roi_start, roi_shape=roi_shape,
                               device='cpu'),
            fn.roi_random_crop(shape_like_in, crop_shape=crop_shape,
                               roi_start=roi_start, roi_end=roi_end,
                               device='cpu'),
            fn.roi_random_crop(in_shape=in_shape, crop_shape=crop_shape,
                               roi_start=roi_start, roi_shape=roi_shape,
                               device='cpu'),
            fn.roi_random_crop(in_shape=in_shape, crop_shape=crop_shape,
                               roi_start=roi_start, roi_end=roi_end,
                               device='cpu'),
        ]

    reference_outs = [in_shape, roi_start, roi_shape, crop_shape]
    num_reference = len(reference_outs)
    num_unbounded = len(unbounded_outs)
    pipe.set_outputs(*reference_outs, *unbounded_outs, *bounded_outs)
    pipe.build()

    for _ in range(niter):
        results = pipe.run()
        for s in range(len(results[0])):
            sample = [np.array(out[s]).tolist() for out in results]
            s_in_shape, s_roi_start, s_roi_shape, s_crop_shape = sample[:num_reference]
            anchors = sample[num_reference:]

            for anchor in anchors[:num_unbounded]:
                check_crop_start(anchor, s_roi_start, s_roi_shape, s_crop_shape)
            for anchor in anchors[num_unbounded:]:
                check_crop_start(anchor, s_roi_start, s_roi_shape, s_crop_shape, s_in_shape)