示例#1
0
def test_dft_3d():
    """Compare ``ng.dft`` over axes 0, 1 and 2 with ``numpy.fft.fftn``.

    The reference is computed on the complex view of the input and then split
    back into a trailing (real, imag) axis for comparison with the nGraph
    result.

    NOTE(review): the ``view``/``squeeze`` below assumes
    ``build_fft_input_data()`` returns float32 data whose trailing axis has
    size 2 and holds (real, imag) pairs — confirm against that helper.
    """
    runtime = get_runtime()
    input_data = build_fft_input_data()
    axes = [0, 1, 2]

    input_tensor = ng.constant(input_data)
    input_axes = ng.constant(np.array(axes, dtype=np.int64))

    dft_node = ng.dft(input_tensor, input_axes)
    computation = runtime.computation(dft_node)
    dft_results = computation()

    # Reinterpret each trailing float pair as a single complex64 value.
    complex_input = np.squeeze(input_data.view(dtype=np.complex64), axis=-1)
    # numpy.fft may compute in complex128 (it always does before NumPy 2.0);
    # cast back to complex64 so the reference matches the input precision.
    np_results = np.fft.fftn(complex_input, axes=axes).astype(np.complex64)
    # Split into a trailing (real, imag) axis. Unlike `.view(np.float32)`, this
    # does not depend on the dtype numpy.fft returns, and it follows the input
    # shape instead of a hard-coded one.
    expected_results = np.stack((np_results.real, np_results.imag), axis=-1)
    assert np.allclose(dft_results, expected_results, atol=0.0002)
示例#2
0
def test_dft_2d_signal_size_2():
    """Compare ``ng.dft`` on axes 1 and 2 with signal size [4, 5] against numpy.

    The same axes and signal sizes are passed to the nGraph node and to
    ``numpy.fft.fft2`` (as ``axes`` and ``s``), so the two always describe
    the same transform.

    NOTE(review): the ``view``/``squeeze`` below assumes
    ``build_fft_input_data()`` returns float32 data whose trailing axis has
    size 2 and holds (real, imag) pairs — confirm against that helper.
    """
    runtime = get_runtime()
    input_data = build_fft_input_data()
    axes = [1, 2]
    signal_size = [4, 5]

    input_tensor = ng.constant(input_data)
    input_axes = ng.constant(np.array(axes, dtype=np.int64))
    input_signal_size = ng.constant(np.array(signal_size, dtype=np.int64))

    dft_node = ng.dft(input_tensor, input_axes, input_signal_size)
    computation = runtime.computation(dft_node)
    dft_results = computation()

    # Reinterpret each trailing float pair as a single complex64 value.
    complex_input = np.squeeze(input_data.view(dtype=np.complex64), axis=-1)
    # numpy.fft may compute in complex128 (it always does before NumPy 2.0);
    # cast back to complex64 so the reference matches the input precision.
    np_results = np.fft.fft2(complex_input, s=signal_size, axes=axes).astype(np.complex64)
    # Split into a trailing (real, imag) axis. Unlike `.view(np.float32)`, this
    # does not depend on the dtype numpy.fft returns, and the output shape
    # follows from `signal_size` instead of being hard-coded.
    expected_results = np.stack((np_results.real, np_results.imag), axis=-1)
    assert np.allclose(dft_results, expected_results, atol=0.000062)
示例#3
0
def test_dft_2d():
    """Check ``ng.dft`` on axes 1 and 2 against ``numpy.fft.fft2``.

    The expected tensor comes from running numpy's 2-D FFT on the complex view
    of the input. The result is cast to complex64, and its real and imaginary
    parts are stacked into a new trailing axis.
    """
    rt = get_runtime()
    data = build_fft_input_data()

    # Build the DFT graph and evaluate it.
    dft_op = ng.dft(ng.constant(data), ng.constant(np.array([1, 2], dtype=np.int64)))
    actual = rt.computation(dft_op)()

    # Reference: treat each trailing float pair as one complex64 value.
    as_complex = np.squeeze(data.view(dtype=np.complex64), axis=-1)
    spectrum = np.fft.fft2(as_complex, axes=[1, 2]).astype(np.complex64)
    expected = np.stack([spectrum.real, spectrum.imag], axis=-1)

    assert np.allclose(actual, expected, atol=0.000062)