示例#1
0
文件: test_ifft.py 项目: sony/nnabla
def test_fft_forward_backward(seed, ctx, func_name, batch_dims, signal_ndim,
                              dims, normalized):
    """Check ``F.ifft`` forward and backward through ``function_tester``.

    A random complex input of shape ``batch_dims + dims`` is generated from
    ``seed``, converted with ``convert_to_float2_array`` and handed to
    ``function_tester`` together with the reference callables ``ref_ifft``
    and ``ref_grad_ifft`` (both defined outside this function).

    ``signal_ndim`` and ``normalized`` are forwarded unchanged as the
    function arguments of ``F.ifft``. The half-precision test is disabled.
    """
    # The plain "IFFT" variant is skipped; per the skip message it has no
    # CPU implementation.
    if func_name == "IFFT":
        pytest.skip("Not implemented in CPU.")

    # NOTE: convert_to_complex_array is imported but not used in this test.
    from nbla_test_utils import (function_tester, convert_to_float2_array,
                                 convert_to_complex_array)

    rng = np.random.RandomState(seed)
    shape = batch_dims + dims

    # Draw the real part first and the imaginary part second so the random
    # stream is consumed in a fixed order.
    real_part = rng.rand(*shape)
    imag_part = rng.rand(*shape)
    x_data = convert_to_float2_array(real_part + 1j * imag_part)

    tester_options = dict(
        func_args=[signal_ndim, normalized],
        atol_f=1e-4,
        atol_b=1e-4,
        backward=[True],
        ctx=ctx,
        func_name=func_name,
        ref_grad=ref_grad_ifft,
        disable_half_test=True,
    )
    function_tester(rng, F.ifft, ref_ifft, [x_data], **tester_options)
示例#2
0
def test_fft_double_backward(seed, ctx, func_name, batch_dims,
                             signal_ndim, dims, normalized):
    """Check the backward function of ``F.ifft`` via ``backward_function_tester``.

    A random complex input of shape ``batch_dims + dims`` is generated from
    ``seed``, converted with ``convert_to_float2_array`` and passed to
    ``backward_function_tester`` with ``[signal_ndim, normalized]`` as the
    function arguments.

    The test is skipped for the plain "IFFT" function and for "IFFTCuda" on
    win32 when the CUDA extension reports version '11.4'.
    """
    # Per the skip message, the plain "IFFT" variant has no CPU implementation.
    if func_name == "IFFT":
        pytest.skip("Not implemented in CPU.")

    # The CUDA extension is only imported when the win32 + IFFTCuda
    # combination actually needs its version string.
    windows_cuda_run = func_name == "IFFTCuda" and sys.platform == 'win32'
    if windows_cuda_run:
        from nnabla_ext import cuda
        if cuda._version.__cuda_version__ == '11.4':
            pytest.skip("Skip win32+CUDA114 tests")

    # NOTE: convert_to_complex_array is imported but not used in this test.
    from nbla_test_utils import (backward_function_tester,
                                 convert_to_float2_array,
                                 convert_to_complex_array)

    rng = np.random.RandomState(seed)
    shape = batch_dims + dims

    # Real part is drawn before the imaginary part (fixed RNG consumption order).
    real_part = rng.rand(*shape)
    imag_part = rng.rand(*shape)
    x_data = convert_to_float2_array(real_part + 1j * imag_part)

    tester_options = dict(
        func_args=[signal_ndim, normalized],
        atol_f=1e-4,
        atol_accum=5e-2,
        backward=[True],
        ctx=ctx,
    )
    backward_function_tester(rng, F.ifft, [x_data], **tester_options)
示例#3
0
def ref_fft(x, signal_ndim, normalized):
    """Reference forward FFT computed with ``np.fft.fftn``.

    ``x`` is converted with ``convert_to_complex_array``, transformed over
    the trailing ``signal_ndim`` axes of that complex array, converted back
    with ``convert_to_float2_array`` and cast to ``float32``.

    ``normalized`` selects ``norm="ortho"``; otherwise NumPy's default
    normalisation (``norm=None``) is used.
    """
    from nbla_test_utils import convert_to_complex_array, convert_to_float2_array

    complex_in = convert_to_complex_array(x)
    full_shape = complex_in.shape

    # Every axis in front of the signal axes is treated as a batch axis.
    n_batch_axes = len(full_shape[0:len(full_shape) - signal_ndim])
    signal_axes = np.arange(signal_ndim) + n_batch_axes

    norm_mode = "ortho" if normalized else None
    complex_out = np.fft.fftn(complex_in, axes=signal_axes, norm=norm_mode)

    return convert_to_float2_array(complex_out).astype(np.float32)
示例#4
0
def ref_grad_fft(x, dy, signal_ndim, normalized):
    """Reference input gradient of the FFT, built from ``np.fft.ifftn``.

    ``dy`` is converted with ``convert_to_complex_array`` and inverse
    transformed over its trailing ``signal_ndim`` axes. The result is
    converted back with ``convert_to_float2_array``, cast to ``float32`` and
    returned flattened. ``x`` is accepted but not used in the computation.
    """
    from nbla_test_utils import convert_to_complex_array, convert_to_float2_array

    grad_out = convert_to_complex_array(dy)
    out_shape = grad_out.shape

    # Every axis in front of the signal axes is treated as a batch axis.
    n_batch_axes = len(out_shape[0:len(out_shape) - signal_ndim])
    signal_axes = np.arange(signal_ndim) + n_batch_axes

    norm_mode = "ortho" if normalized else None
    grad_in = np.fft.ifftn(grad_out, axes=signal_axes, norm=norm_mode)

    if not normalized:
        # Without "ortho", ifftn scales by 1/n over the transformed axes;
        # multiplying by the product of the signal-axis sizes cancels that
        # factor. Slicing from index 0 yields the whole shape, so the
        # no-batch case needs no special handling.
        grad_in *= np.prod(grad_in.shape[n_batch_axes:])

    return convert_to_float2_array(grad_in).astype(np.float32).flatten()
示例#5
0
def test_fft_double_backward(seed, ctx, func_name, batch_dims,
                             signal_ndim, dims, normalized):
    """Check the backward function of ``F.fft`` via ``backward_function_tester``.

    A random complex input of shape ``batch_dims + dims`` is generated from
    ``seed``, converted with ``convert_to_float2_array`` and passed to
    ``backward_function_tester`` with ``[signal_ndim, normalized]`` as the
    function arguments.
    """
    # The plain "FFT" variant is skipped; per the skip message it has no
    # CPU implementation.
    if func_name == "FFT":
        pytest.skip("Not implemented in CPU.")

    # NOTE: convert_to_complex_array is imported but not used in this test.
    from nbla_test_utils import (backward_function_tester,
                                 convert_to_float2_array,
                                 convert_to_complex_array)

    rng = np.random.RandomState(seed)
    shape = batch_dims + dims

    # Real part is drawn before the imaginary part (fixed RNG consumption order).
    real_part = rng.rand(*shape)
    imag_part = rng.rand(*shape)
    x_data = convert_to_float2_array(real_part + 1j * imag_part)

    tester_options = dict(
        func_args=[signal_ndim, normalized],
        atol_f=1e-3,
        atol_accum=8e-2,
        backward=[True],
        ctx=ctx,
    )
    backward_function_tester(rng, F.fft, [x_data], **tester_options)