def _check_subsample_fourier(backend, use_cuda):
    """Check ``backend.SubsampleFourier`` against a brute-force reference.

    A random double-precision tensor of shape ``(100, 1, 128, 128, 2)`` is
    reduced to ``(100, 1, 8, 8, 2)`` by averaging, for every output position
    ``(i, j)``, the 16 x 16 input samples located at
    ``(i + 8 * m, j + 8 * n)``.  The backend operator called with ``k=16``
    must agree with that reference to within ``1e-8``.

    Args:
        backend: Object exposing a ``NAME`` attribute and a
            ``SubsampleFourier`` factory.
        use_cuda: When true, the input and the reference are moved to the
            GPU before the comparison.

    Raises:
        AssertionError: If the backend output deviates from the reference.
    """
    x = torch.rand(100, 1, 128, 128, 2)
    y = torch.zeros(100, 1, 8, 8, 2)
    if use_cuda:
        x = x.cuda()
        y = y.cuda()
    x = x.double()
    y = y.double()

    # The reference is intentionally written as explicit loops so that it
    # stays independent of whatever reshaping trick the backend uses.
    for i in range(8):
        for j in range(8):
            for m in range(16):
                for n in range(16):
                    y[..., i, j, :] += x[..., i + m * 8, j + n * 8, :]
    y = y / (16 * 16)

    subsample_fourier = backend.SubsampleFourier()

    z = subsample_fourier(x, k=16)
    assert (y - z).abs().max() < 1e-8

    # The 'torch' backend is additionally exercised on a CPU copy of the
    # same input (a no-op copy when the data already lives on the CPU).
    if backend.NAME == 'torch':
        z = subsample_fourier(x.cpu(), k=16)
        assert (y.cpu() - z).abs().max() < 1e-8


def test_SubsampleFourier():
    """Run the SubsampleFourier check for every configured device/backend.

    Iterates over the module-level ``devices`` and ``backends``.  On
    ``'gpu'`` every backend is checked with CUDA tensors; on ``'cpu'`` the
    ``'skcuda'`` backend is skipped and the remaining ones are checked with
    CPU tensors.

    Raises:
        RuntimeError: If a device name other than ``'gpu'`` or ``'cpu'`` is
            encountered.
        AssertionError: If any backend disagrees with the reference.
    """
    for device in devices:
        if device == 'gpu':
            for backend in backends:
                _check_subsample_fourier(backend, use_cuda=True)
        elif device == 'cpu':
            for backend in backends:
                if backend.NAME == 'skcuda':
                    continue
                _check_subsample_fourier(backend, use_cuda=False)
        else:
            # Raising a bare string is a TypeError in Python 3; raise a real
            # exception so the message actually reaches the user.
            raise RuntimeError('No backend or device detected.')
# Benchmark set-up excerpt: builds a table of backend operations and invokes
# each one once.
# NOTE(review): `B`, `N` and `backend` are not defined in this excerpt; they
# must be provided by the surrounding code.
# NOTE(review): `repeats` is not used in the visible statements; presumably
# it is consumed by timing code outside this excerpt - confirm.
repeats = 3
# Random inputs: `x` has shape (B, 1, N, N, 2) and `y` has shape (N, N, 2).
x = torch.randn(B, 1, N, N, 2)
y = torch.randn(N, N, 2)
# Both tensors are moved to the CUDA device and made contiguous.
x = x.cuda()
y = y.cuda()
x = x.contiguous()
y = y.contiguous()
# Each factor below adds one 'subsample' entry to the operation table.
subsample_factors = (8, 16, 32, 64)
# Operations taken from the backend under test.
cdgmm = backend.cdgmm
modulus = backend.Modulus()
subsample = backend.SubsampleFourier()
# Four parallel lists describe each call: a label, the callable, its tensor
# arguments, and any extra positional parameters.
name_list = ['cdgmm', 'modulus']
func_list = [cdgmm, modulus]
args_list = [(x, y), (x, )]
params_list = [(), ()]
for subsample_factor in subsample_factors:
    name_list.append('subsample')
    func_list.append(subsample)
    args_list.append((x, ))
    params_list.append((subsample_factor, ))
# Call every operation with its tensor arguments followed by its extra
# parameters, all passed positionally.
for name, func, args, params in zip(name_list, func_list, args_list, params_list):
    out = func(*(args + params))