def test_array_to_blocks(self):
    """Check ``block.array_to_blocks`` on the ramp ``[0, 1, 2, 3, 4, 5]``.

    Every block configuration below is exercised for each available array
    module (NumPy, plus CuPy when ``config.cupy_enabled``), for float32 and
    complex64, and for 1 to 3 dimensions. Every dimension after the first is
    a singleton, both in the input and in the block shape and strides.
    """
    # Each entry is (block length, block stride, expected blocks), all along
    # the leading axis. The expected array is reshaped to
    # [num_blocks, 1, ..., block_length, 1, ...].
    cases = [
        (1, 1, [[0], [1], [2], [3], [4], [5]]),
        (2, 1, [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]),
        (2, 2, [[0, 1], [2, 3], [4, 5]]),
        (3, 2, [[0, 1, 2], [2, 3, 4]]),
    ]
    modules = [np, cp] if config.cupy_enabled else [np]
    for xp in modules:
        for dtype in (np.float32, np.complex64):
            for ndim in range(1, 4):
                # Singleton padding applied to every trailing dimension.
                ones = [1] * (ndim - 1)
                with self.subTest(xp=xp, dtype=dtype, ndim=ndim):
                    arr = xp.array([0, 1, 2, 3, 4, 5],
                                   dtype=dtype).reshape([6] + ones)
                    for blk_len, stride, blocks in cases:
                        expected = xp.array(blocks, dtype=dtype).reshape(
                            [len(blocks)] + ones + [blk_len] + ones)
                        xp.testing.assert_allclose(
                            expected,
                            block.array_to_blocks(
                                arr, [blk_len] + ones, [stride] + ones))
def _apply(self, input):
    """Split ``input`` into blocks.

    Delegates to ``block.array_to_blocks`` with this operator's
    ``blk_shape`` and ``blk_strides`` attributes, and returns its result
    unchanged.
    """
    shape, strides = self.blk_shape, self.blk_strides
    blocks = block.array_to_blocks(input, shape, strides)
    return blocks
def _apply(self, input):
    """Split ``input`` into blocks on the device that holds ``input``.

    ``block.array_to_blocks`` is called with this operator's ``blk_shape``
    and ``blk_strides``. The call runs inside the context manager returned
    by ``backend.get_device(input)``, so that any device-specific work is
    scoped to the input's device (presumably relevant for CuPy arrays;
    verify against ``backend``).
    """
    with backend.get_device(input):
        return block.array_to_blocks(input, self.blk_shape, self.blk_strides)