# Example #1
# 0
def _check_modulus(backend, x):
    """Assert that ``backend.Modulus()`` agrees with a reference modulus on ``x``.

    The reference is ``sqrt(sum(x * x, dim=3))``. ``x`` is presumably a
    complex tensor whose size-2 trailing axis holds (real, imag) parts;
    confirm against the backend. Only index 0 along dim 3 of the backend
    output is compared. The test passes when the max absolute difference
    is below 1e-6.
    """
    modulus = backend.Modulus()
    y = modulus(x)
    # Reference: reduce the squared components over dim 3, then take sqrt.
    expected = torch.sqrt(torch.sum(x * x, 3)).squeeze()
    # Only the first slot along dim 3 of the backend result is checked.
    actual = y.narrow(3, 0, 1).squeeze()
    assert (expected - actual).abs().max() < 1e-6


def test_Modulus():
    """Check every backend's ``Modulus`` on each configured device.

    On 'gpu', every backend is tested with a CUDA tensor. On 'cpu', every
    backend except 'skcuda' is tested. Any other device name raises
    RuntimeError.
    """
    for device in devices:
        if device == 'gpu':
            for backend in backends:
                x = torch.rand(100, 10, 4, 2).cuda().float()
                _check_modulus(backend, x)
        elif device == 'cpu':
            for backend in backends:
                # The skcuda backend is not exercised on CPU.
                if backend.NAME == 'skcuda':
                    continue
                x = torch.rand(100, 10, 4, 2).float()
                _check_modulus(backend, x)
        else:
            # Raising a plain string is a TypeError in Python 3; raise a real
            # exception carrying the offending device name instead.
            raise RuntimeError(
                'No backend or device detected (device=%r).' % (device,))
# Iteration counts; presumably consumed by timing code later in the file.
trials = 10
repeats = 3

# Random inputs with a trailing axis of size 2 (presumably real/imag parts),
# moved to the GPU and made contiguous. x is drawn before y, as before.
x = torch.randn(B, 1, N, N, 2).cuda().contiguous()
y = torch.randn(N, N, 2).cuda().contiguous()

subsample_factors = (8, 16, 32, 64)

# Operators taken from the selected backend.
cdgmm = backend.cdgmm
modulus = backend.Modulus()
subsample = backend.SubsampleFourier()

# Parallel lists, one entry per case: name, callable, positional inputs,
# and extra parameters.
name_list, func_list = ['cdgmm', 'modulus'], [cdgmm, modulus]
args_list, params_list = [(x, y), (x, )], [(), ()]

# Add one 'subsample' case per factor; each case reuses the same operator and
# input and differs only in its factor.
for subsample_factor in subsample_factors:
    name_list += ['subsample']
    func_list += [subsample]
    args_list += [(x, )]
    params_list += [(subsample_factor, )]

for name, func, args, params in zip(name_list, func_list, args_list,
                                    params_list):