示例#1
0
def _subsample_fourier_reference(x, k):
    """Return the expected ``SubsampleFourier`` output, computed by brute force.

    Each output coefficient ``(i, j)`` is the average of the ``k * k`` input
    coefficients ``(i + m * out_rows, j + n * out_cols)`` for ``m, n`` in
    ``range(k)``, where ``out_rows``/``out_cols`` are the output spatial sizes.
    The explicit loops are deliberate: they give a reference that does not
    share code with any vectorized implementation under test.

    Args:
        x: tensor whose last three dimensions are ``(rows, cols, C)``; the
            trailing dimension (2 in the tests) is presumably the
            (real, imag) pair — confirm against the backend.
        k: subsampling factor; ``rows`` and ``cols`` are assumed divisible
            by ``k``.

    Returns:
        A tensor with the same dtype and device as ``x`` whose two spatial
        sizes are divided by ``k``.
    """
    out_rows = x.shape[-3] // k
    out_cols = x.shape[-2] // k
    # new_zeros keeps x's dtype and device, so the GPU case stays on the GPU.
    y = x.new_zeros(tuple(x.shape[:-3]) + (out_rows, out_cols, x.shape[-1]))
    for i in range(out_rows):
        for j in range(out_cols):
            for m in range(k):
                for n in range(k):
                    y[..., i, j, :] += x[..., i + m * out_rows,
                                         j + n * out_cols, :]
    return y / (k * k)


def _check_subsample_fourier(backend, x, k):
    """Assert that ``backend.SubsampleFourier`` matches the reference on ``x``.

    For the ``'torch'`` backend with a CUDA input, the op is also run on a
    CPU copy of ``x`` and compared against the same reference.
    """
    expected = _subsample_fourier_reference(x, k)
    subsample_fourier = backend.SubsampleFourier()

    z = subsample_fourier(x, k=k)
    assert (expected - z).abs().max() < 1e-8

    # For a CPU input this re-check would repeat the assertion above verbatim,
    # so only run it when x lives on the GPU.
    if backend.NAME == 'torch' and x.is_cuda:
        z = subsample_fourier(x.cpu(), k=k)
        assert (expected.cpu() - z).abs().max() < 1e-8


def test_SubsampleFourier():
    """Check every backend's ``SubsampleFourier`` on every listed device.

    Relies on ``devices`` (values ``'gpu'`` or ``'cpu'``) and ``backends``
    (objects exposing ``NAME`` and ``SubsampleFourier``) defined outside
    this function.
    """
    k = 16
    for device in devices:
        if device not in ('gpu', 'cpu'):
            # A bare string cannot be raised in Python 3 (that is itself a
            # TypeError); raise a real exception carrying the message.
            raise RuntimeError('No backend or device detected.')
        for backend in backends:
            # skcuda is skipped on CPU (presumably it is CUDA-only).
            if device == 'cpu' and backend.NAME == 'skcuda':
                continue
            x = torch.rand(100, 1, 128, 128, 2).double()
            if device == 'gpu':
                x = x.cuda()
            _check_subsample_fourier(backend, x, k)
示例#2
0
# NOTE(review): this fragment reads `B`, `N` and `backend`, which are not
# defined in the visible lines — presumably they are set earlier in the script.
# `repeats` is not used in the visible lines; presumably a timing loop that
# uses it follows this fragment.
repeats = 3

# Random inputs of shape (B, 1, N, N, 2) and (N, N, 2).
x = torch.randn(B, 1, N, N, 2)
y = torch.randn(N, N, 2)

# Move both inputs to the GPU ...
x = x.cuda()
y = y.cuda()

# ... and ensure a contiguous memory layout before handing them to the ops.
x = x.contiguous()
y = y.contiguous()

# Subsampling factors to exercise with `SubsampleFourier`.
subsample_factors = (8, 16, 32, 64)

cdgmm = backend.cdgmm
modulus = backend.Modulus()
subsample = backend.SubsampleFourier()

# Parallel lists describing the calls to make: entry i is labelled
# name_list[i] and invoked as
#     func_list[i](*(args_list[i] + params_list[i]))
name_list = ['cdgmm', 'modulus']
func_list = [cdgmm, modulus]
args_list = [(x, y), (x, )]
params_list = [(), ()]

# One 'subsample' entry per factor. The factor is passed as an extra
# positional argument after `x` (presumably the subsampling factor `k`).
for subsample_factor in subsample_factors:
    name_list.append('subsample')
    func_list.append(subsample)
    args_list.append((x, ))
    params_list.append((subsample_factor, ))

# Call each op once with its arguments followed by its extra parameters.
for name, func, args, params in zip(name_list, func_list, args_list,
                                    params_list):
    # `out` is not used in the visible lines — presumably a warm-up call, or
    # the loop body continues beyond this fragment.
    out = func(*(args + params))