def test_FFT3d_central_freq_batch(device, backend):
    # Checked the 0 frequency for the 3D FFT
    for device in devices:
        x = torch.zeros(1, 32, 32, 32, 2).float()
        if device == 'gpu':
            x = x.cuda()
        a = x.sum()
        y = backend.fft(x)
        c = y[:, 0, 0, 0].sum()
        assert (c - a).abs().sum() < 1e-6
def test_fft3d_error(backend, device):
    """Check that ``backend.fft`` rejects an (8, 1) zero tensor.

    The call must raise ``TypeError`` whose first argument contains
    "should be complex". ``device`` is not used in the body; presumably it is
    requested only so the test is parametrized like its siblings -- confirm.
    """
    not_complex = torch.zeros(8, 1)
    with pytest.raises(TypeError) as excinfo:
        backend.fft(not_complex)
    message = excinfo.value.args[0]
    assert "should be complex" in message