def test_divide_combine(self):
    """Check that recombining border-cropped patches reproduces the cropped input.

    Each patch produced by ``divide`` has a one-element border stripped from
    its last three axes. The stripped patches are combined, and the result is
    compared against the input cropped the same way.
    """
    # Keep the leading axis whole; drop one element on each side of the other three.
    inner = (slice(None),) + (slice(1, -1),) * 3
    expected_shape = np.array((3, 40, 56, 72))

    source = np.random.randn(*self.x_shape)
    cropped_patches = [
        patch[inner] for patch in divide(source, self.patch_size, self.stride)
    ]
    restored = combine(cropped_patches, expected_shape, self.stride)

    np.testing.assert_equal(source[inner], restored)
def wrapper(x, *args, **kwargs):
    """Apply ``predict`` to a grid of patches of ``x`` and combine the outputs.

    The patch axes are resolved from ``axis`` via ``resolve_deprecation``.
    When ``valid`` is set, ``x`` is first padded so the patch/stride grid
    tiles it exactly, and the combined prediction is cropped back to the
    original extent afterwards. Extra ``*args``/``**kwargs`` are passed
    through to ``pmap`` together with ``predict``.
    """
    input_axis = resolve_deprecation(axis, x.ndim, patch_size, stride)
    local_size, local_stride = broadcast_to_axis(input_axis, patch_size, stride)

    if valid:
        shape = extract(x.shape, input_axis)
        # Grow every axis to at least one patch, then up to the next size for
        # which (size - patch) is a multiple of the stride.
        padded_shape = np.maximum(shape, local_size)
        new_shape = padded_shape + (local_stride - padded_shape + local_size) % local_stride
        x = pad_to_shape(x, new_shape, input_axis, padding_values, ratio)
    else:
        # Without padding the grid spans ``x`` as-is. ``new_shape`` must be
        # bound on this path too, otherwise ``divide_grid`` below raises
        # UnboundLocalError.
        # NOTE(review): assumes divide_grid's 2nd argument is the extent of x
        # along input_axis, as on the padded path.
        new_shape = extract(x.shape, input_axis)

    grid = divide_grid(x, new_shape, local_size, local_stride, input_axis)
    patches = pmap(predict, grid, *args, **kwargs)
    # The raw ``axis`` (not ``input_axis``) is used on the output side,
    # presumably because the prediction's ndim may differ from x's.
    # NOTE(review): confirm this is intentional.
    prediction = combine(patches, extract(x.shape, input_axis), local_stride, axis)

    if valid:
        prediction = crop_to_shape(prediction, shape, axis, ratio)
    return prediction
def test_combine_grid_patches(self):
    """Check that non-overlapping patches recombine into the original array.

    The stride equals the patch size. Twenty random spatial shapes in
    [40, 50) per axis are tried, each as its own subtest.
    """
    patch_size = [20] * 3
    stride = patch_size  # same list object: the patches tile without overlap

    for _ in range(20):
        spatial_shape = np.random.randint(40, 50, size=3)
        with self.subTest(shape=spatial_shape):
            x = np.random.randn(1, *spatial_shape)
            parts = divide(x, patch_size, stride)
            np.testing.assert_array_almost_equal(
                x, combine(parts, spatial_shape, stride))
def test_combine_int(self):
    """Check that an integer array survives a divide/combine round trip.

    Patches of size 20 are taken with stride 10, so neighbouring patches
    overlap by half.
    """
    patch_size = np.array([20, 20, 20], dtype=int)
    stride = patch_size // 2
    spatial_shape = [45, 43, 48]

    x = np.random.randint(0, 100, size=(1, *spatial_shape))
    parts = divide(x, patch_size, stride)
    np.testing.assert_array_almost_equal(
        x, combine(parts, spatial_shape, stride))
def wrapper(x):
    """Apply ``predict`` to each patch of ``x`` along ``axes`` and combine the outputs.

    When ``valid`` is set, ``x`` is padded first so that the patch/stride grid
    tiles it exactly. The combined prediction is then cropped back to the
    original extent along ``axes``.
    """
    def stitch(array):
        # Predict every patch lazily and merge the outputs over array's extent.
        outputs = map(predict, divide(array, patch_size, stride, axes))
        return combine(outputs, extract(array.shape, axes), stride, axes)

    if not valid:
        return stitch(x)

    original_shape = np.array(x.shape)[list(axes)]
    # At least one patch per axis, then round up so (size - patch) % stride == 0.
    grown = np.maximum(original_shape, patch_size)
    target_shape = grown + (stride - grown + patch_size) % stride
    padded = pad_to_shape(x, target_shape, axes, padding_values, ratio)

    return crop_to_shape(stitch(padded), original_shape, axes, ratio)