def test_convolve_data_adjoint_full(self):
    """Check ``conv.convolve_data_adjoint`` in ``'full'`` mode.

    Each case builds an all-ones ``output`` and ``filt`` and compares the
    adjoint result against hand-computed values. Cases run on the CPU and,
    when CuPy is enabled, on GPU 0, for every dtype in ``dtypes``.

    Every result is copied back to ``backend.cpu_device`` explicitly before
    comparison, because ``npt.assert_allclose`` operates on host arrays.
    """
    mode = 'full'
    devices = [backend.cpu_device]
    if config.cupy_enabled:
        devices.append(backend.Device(0))

    # (output shape, filter shape, data shape, extra kwargs, expected data)
    cases = [
        ([1, 5], [1, 3], [1, 3], {}, [[3, 3, 3]]),
        ([1, 4], [1, 2], [1, 3], {}, [[2, 2, 2]]),
        ([2, 1, 5], [2, 1, 1, 3], [1, 1, 3],
         dict(multi_channel=True),
         [[[6, 6, 6]]]),
        ([2, 1, 5], [2, 1, 1, 3], [1, 1, 8],
         dict(strides=[1, 2], multi_channel=True),
         [[[4, 2, 4, 2, 4, 2, 4, 2]]]),
    ]

    for device in devices:
        xp = device.xp
        with device:
            for dtype in dtypes:
                with self.subTest(dtype=dtype, device=device):
                    for oshape, fshape, data_shape, kwargs, expected in cases:
                        output = xp.ones(oshape, dtype=dtype)
                        filt = xp.ones(fshape, dtype=dtype)
                        # Always name the target device so the result is a
                        # host array regardless of where it was computed.
                        data = backend.to_device(
                            conv.convolve_data_adjoint(
                                output, filt, data_shape,
                                mode=mode, **kwargs),
                            backend.cpu_device)
                        npt.assert_allclose(data, expected, atol=1e-5)
def _apply(self, input):
    """Apply the adjoint of data convolution to ``input``.

    ``self.filt`` is transferred to the device that holds ``input``, and
    ``conv.convolve_data_adjoint`` runs under that device's context. This
    keeps the filter and the input on the same device.

    Args:
        input: array to which the adjoint operator is applied.

    Returns:
        The result of ``conv.convolve_data_adjoint`` with shape argument
        ``self.oshape`` and this operator's ``mode``, ``strides`` and
        ``multi_channel`` settings.
    """
    device = backend.get_device(input)
    # The stored filter may live on a different device than ``input``;
    # move it so both operands of the convolution match.
    filt = backend.to_device(self.filt, device)
    with device:
        return conv.convolve_data_adjoint(
            input, filt, self.oshape,
            mode=self.mode, strides=self.strides,
            multi_channel=self.multi_channel)
def _apply(self, input):
    """Run the data-convolution adjoint on ``input``'s own device.

    The operator's filter is first copied onto whichever device ``input``
    resides on. The adjoint convolution is then evaluated inside that
    device's context.

    Args:
        input: array the adjoint operator is applied to.

    Returns:
        Output of ``conv.convolve_data_adjoint`` for shape ``self.oshape``
        using the operator's ``mode``, ``strides`` and ``multi_channel``.
    """
    target = backend.get_device(input)
    kernel = backend.to_device(self.filt, target)
    with target:
        options = dict(
            mode=self.mode,
            strides=self.strides,
            multi_channel=self.multi_channel,
        )
        return conv.convolve_data_adjoint(input, kernel, self.oshape, **options)