def check_roi_random_crop_error(shape_like_in=None, in_shape=None, crop_shape=None, roi_start=None,
                                roi_shape=None, roi_end=None):
    """Check that an invalid ``fn.roi_random_crop`` configuration is rejected.

    Builds a single-operator CPU pipeline from the given arguments and expects a
    ``RuntimeError`` either while the pipeline is built or while it is run for a
    few iterations.

    Args:
        shape_like_in: Optional positional input for ``fn.roi_random_crop``; it
            is only passed when not ``None``.
        in_shape, crop_shape, roi_start, roi_shape, roi_end: Forwarded verbatim
            as the same-named ``fn.roi_random_crop`` arguments.
    """
    batch_size = 3
    niter = 3
    pipe = dali.pipeline.Pipeline(batch_size=batch_size, num_threads=4, device_id=0, seed=1234)
    with pipe:
        # The positional input is optional, so splat an empty list when it is absent.
        inputs = [] if shape_like_in is None else [shape_like_in]
        out = fn.roi_random_crop(*inputs, in_shape=in_shape, crop_shape=crop_shape,
                                 roi_start=roi_start, roi_shape=roi_shape, roi_end=roi_end,
                                 device='cpu')
    pipe.set_outputs(out)
    # Both build and run are inside the assertion context. Presumably some
    # invalid arguments are only detected when the operator actually runs.
    with assert_raises(RuntimeError):
        pipe.build()
        for _ in range(niter):
            pipe.run()
def biased_crop_fn(self, img, label):
    """Crop ``img`` and ``label`` around a randomly selected object bounding box.

    A bounding box is drawn from ``label`` with
    ``fn.segmentation.random_object_bbox`` (``foreground_prob`` comes from
    ``self.oversampling``, the seed from ``self.internal_seed``). A random
    ``[1, *self.patch_size]`` window anchor is then chosen relative to that box.
    Its spatial part drives a padded ``fn.slice`` of ``self.crop_shape`` over
    the D, H, W axes of both tensors.

    Returns:
        The cropped ``(img, label)`` pair, each moved to the GPU via ``.gpu()``.
    """
    bbox_begin, bbox_end = fn.segmentation.random_object_bbox(
        label,
        format="start_end",
        foreground_prob=self.oversampling,
        background=0,
        seed=self.internal_seed,
        device="cpu",
        cache_objects=True,
    )
    # The label's leading axis is covered by a window extent of 1.
    window_shape = [1, *self.patch_size]
    full_anchor = fn.roi_random_crop(label, roi_start=bbox_begin, roi_end=bbox_end, crop_shape=window_shape)
    # Skip the first (channel) coordinate; keep the next 3, which index the D, H, W axes.
    spatial_anchor = fn.slice(full_anchor, 1, 3, axes=[0])
    img_crop, label_crop = fn.slice(
        [img, label],
        spatial_anchor,
        self.crop_shape,
        axis_names="DHW",
        out_of_bounds_policy="pad",
        device="cpu",
    )
    return img_crop.gpu(), label_crop.gpu()
def _check_roi_crop_start(crop_start, roi_start, roi_shape, crop_shape, in_shape=None):
    """Assert that a crop window is placed correctly with respect to an ROI.

    For each dimension ``d``:

    * if the crop extent is at least the ROI extent, the crop must contain the ROI;
    * otherwise the crop must lie entirely inside the ROI;
    * if ``in_shape`` is given, the crop must also satisfy
      ``0 <= crop_start[d]`` and ``crop_start[d] + crop_shape[d] <= in_shape[d]``.

    Args:
        crop_start: Per-dimension crop anchor produced by ``fn.roi_random_crop``.
        roi_start: Per-dimension ROI start.
        roi_shape: Per-dimension ROI extent.
        crop_shape: Per-dimension crop extent.
        in_shape: Optional per-dimension input extent; bounds are not checked when ``None``.

    Raises:
        AssertionError: If any placement constraint is violated.
    """
    for d in range(len(crop_start)):
        crop_end = crop_start[d] + crop_shape[d]
        roi_end = roi_start[d] + roi_shape[d]
        msg = ("dim {}: crop_start={} crop_shape={} roi_start={} roi_shape={} in_shape={}"
               .format(d, crop_start, crop_shape, roi_start, roi_shape, in_shape))
        if in_shape is not None:
            assert crop_start[d] >= 0, msg
            assert crop_end <= in_shape[d], msg
        if crop_shape[d] >= roi_shape[d]:
            # A crop at least as large as the ROI must cover it entirely.
            assert crop_start[d] <= roi_start[d], msg
            assert crop_end >= roi_end, msg
        else:
            # A crop smaller than the ROI must lie entirely inside it.
            assert crop_start[d] >= roi_start[d], msg
            assert crop_end <= roi_end, msg


def check_roi_random_crop(ndim=2, max_batch_size=16, roi_min_start=0, roi_max_start=100,
                          roi_min_extent=20, roi_max_extent=50, crop_min_extent=20,
                          crop_max_extent=50, in_shape_min=400, in_shape_max=500, niter=3):
    """Run ``fn.roi_random_crop`` in six argument combinations and validate every anchor.

    Input shapes are generated by an external source. Crop shape and ROI
    (start, shape/end) are each randomly chosen to be either Python-list
    constants or per-sample ``fn.random.uniform`` tensors. The operator is
    invoked:

    * without any input shape (with ``roi_shape`` and with ``roi_end``);
    * with a shape-like positional input (both ROI forms);
    * with an explicit ``in_shape`` argument (both ROI forms).

    Each produced anchor is checked with ``_check_roi_crop_start``. Only the
    variants that know the input shape are also checked against its bounds.
    """
    pipe = dali.pipeline.Pipeline(batch_size=max_batch_size, num_threads=4, device_id=0, seed=1234)
    with pipe:
        assert in_shape_min < in_shape_max
        shape_gen_fn = lambda: random_shape(in_shape_min, in_shape_max, ndim)
        data_gen_f = lambda: batch_gen(max_batch_size, shape_gen_fn)
        shape_like_in = dali.fn.external_source(data_gen_f, device='cpu')
        in_shape = dali.fn.shapes(shape_like_in, dtype=types.INT32)

        # Randomly exercise both constant (list) and per-sample (tensor) crop shapes.
        if random.choice([True, False]):
            crop_shape = [(crop_min_extent + crop_max_extent) // 2] * ndim
        else:
            crop_shape = fn.random.uniform(range=(crop_min_extent, crop_max_extent + 1),
                                           shape=(ndim,), dtype=types.INT32, device='cpu')

        # Same for the ROI description.
        if random.choice([True, False]):
            roi_shape = [(roi_min_extent + roi_max_extent) // 2] * ndim
            roi_start = [(roi_min_start + roi_max_start) // 2] * ndim
            roi_end = [roi_start[d] + roi_shape[d] for d in range(ndim)]
        else:
            roi_shape = fn.random.uniform(range=(roi_min_extent, roi_max_extent + 1),
                                          shape=(ndim,), dtype=types.INT32, device='cpu')
            roi_start = fn.random.uniform(range=(roi_min_start, roi_max_start + 1),
                                          shape=(ndim,), dtype=types.INT32, device='cpu')
            roi_end = roi_start + roi_shape

        # The first two variants have no input shape information; the rest are bounded by it.
        unbounded_outs = [
            fn.roi_random_crop(crop_shape=crop_shape, roi_start=roi_start, roi_shape=roi_shape,
                               device='cpu'),
            fn.roi_random_crop(crop_shape=crop_shape, roi_start=roi_start, roi_end=roi_end,
                               device='cpu'),
        ]
        bounded_outs = [
            fn.roi_random_crop(shape_like_in, crop_shape=crop_shape, roi_start=roi_start,
                               roi_shape=roi_shape, device='cpu'),
            fn.roi_random_crop(shape_like_in, crop_shape=crop_shape, roi_start=roi_start,
                               roi_end=roi_end, device='cpu'),
            fn.roi_random_crop(in_shape=in_shape, crop_shape=crop_shape, roi_start=roi_start,
                               roi_shape=roi_shape, device='cpu'),
            fn.roi_random_crop(in_shape=in_shape, crop_shape=crop_shape, roi_start=roi_start,
                               roi_end=roi_end, device='cpu'),
        ]

    # Output layout: [in_shape, roi_start, roi_shape, crop_shape, *unbounded, *bounded].
    meta_outs = [in_shape, roi_start, roi_shape, crop_shape]
    pipe.set_outputs(*meta_outs, *unbounded_outs, *bounded_outs)
    pipe.build()
    first_crop = len(meta_outs)
    first_bounded = first_crop + len(unbounded_outs)

    for _ in range(niter):
        results = pipe.run()
        for s in range(len(results[0])):
            sample = [np.array(out[s]).tolist() for out in results]
            s_in_shape, s_roi_start, s_roi_shape, s_crop_shape = sample[:first_crop]
            for idx in range(first_crop, len(sample)):
                bound = s_in_shape if idx >= first_bounded else None
                _check_roi_crop_start(sample[idx], s_roi_start, s_roi_shape, s_crop_shape, bound)